From 2292557fd65f4234e83f2f119975e08d04a2a1a8 Mon Sep 17 00:00:00 2001 From: Snider Date: Sat, 21 Feb 2026 19:14:04 +0000 Subject: [PATCH] chore: vendor MLX C headers for Go module consumers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit dist/include/ contains the MLX and MLX-C headers needed for CGo compilation. Without these, go-mlx cannot be used as a module dependency (headers not found in module cache). Libraries (dylib/metallib) are still gitignored — users build those locally via cmake. Co-Authored-By: Virgil --- .gitignore | 5 +- .../metal_cpp/Foundation/Foundation.hpp | 47 + dist/include/metal_cpp/Foundation/NSArray.hpp | 124 + .../Foundation/NSAutoreleasePool.hpp | 83 + .../include/metal_cpp/Foundation/NSBundle.hpp | 374 ++ dist/include/metal_cpp/Foundation/NSData.hpp | 54 + dist/include/metal_cpp/Foundation/NSDate.hpp | 53 + .../metal_cpp/Foundation/NSDefines.hpp | 45 + .../metal_cpp/Foundation/NSDictionary.hpp | 128 + .../metal_cpp/Foundation/NSEnumerator.hpp | 78 + dist/include/metal_cpp/Foundation/NSError.hpp | 173 + dist/include/metal_cpp/Foundation/NSLock.hpp | 118 + .../metal_cpp/Foundation/NSNotification.hpp | 110 + .../include/metal_cpp/Foundation/NSNumber.hpp | 501 +++ .../metal_cpp/Foundation/NSObjCRuntime.hpp | 43 + .../include/metal_cpp/Foundation/NSObject.hpp | 302 ++ .../metal_cpp/Foundation/NSPrivate.hpp | 531 +++ .../metal_cpp/Foundation/NSProcessInfo.hpp | 386 ++ dist/include/metal_cpp/Foundation/NSRange.hpp | 83 + dist/include/metal_cpp/Foundation/NSSet.hpp | 87 + .../metal_cpp/Foundation/NSSharedPtr.hpp | 310 ++ .../include/metal_cpp/Foundation/NSString.hpp | 255 ++ dist/include/metal_cpp/Foundation/NSTypes.hpp | 51 + dist/include/metal_cpp/Foundation/NSURL.hpp | 90 + dist/include/metal_cpp/LICENSE.txt | 202 + .../Metal/MTL4AccelerationStructure.hpp | 1395 +++++++ dist/include/metal_cpp/Metal/MTL4Archive.hpp | 93 + .../metal_cpp/Metal/MTL4ArgumentTable.hpp | 187 + 
.../metal_cpp/Metal/MTL4BinaryFunction.hpp | 50 + .../Metal/MTL4BinaryFunctionDescriptor.hpp | 97 + .../metal_cpp/Metal/MTL4CommandAllocator.hpp | 100 + .../metal_cpp/Metal/MTL4CommandBuffer.hpp | 193 + .../metal_cpp/Metal/MTL4CommandEncoder.hpp | 134 + .../metal_cpp/Metal/MTL4CommandQueue.hpp | 283 ++ .../metal_cpp/Metal/MTL4CommitFeedback.hpp | 62 + dist/include/metal_cpp/Metal/MTL4Compiler.hpp | 345 ++ .../metal_cpp/Metal/MTL4CompilerTask.hpp | 63 + .../Metal/MTL4ComputeCommandEncoder.hpp | 300 ++ .../metal_cpp/Metal/MTL4ComputePipeline.hpp | 158 + dist/include/metal_cpp/Metal/MTL4Counters.hpp | 138 + .../Metal/MTL4FunctionDescriptor.hpp | 49 + .../metal_cpp/Metal/MTL4LibraryDescriptor.hpp | 98 + .../Metal/MTL4LibraryFunctionDescriptor.hpp | 86 + .../metal_cpp/Metal/MTL4LinkingDescriptor.hpp | 204 + .../MTL4MachineLearningCommandEncoder.hpp | 66 + .../Metal/MTL4MachineLearningPipeline.hpp | 172 + .../Metal/MTL4MeshRenderPipeline.hpp | 413 ++ .../Metal/MTL4PipelineDataSetSerializer.hpp | 85 + .../metal_cpp/Metal/MTL4PipelineState.hpp | 150 + .../Metal/MTL4RenderCommandEncoder.hpp | 340 ++ .../metal_cpp/Metal/MTL4RenderPass.hpp | 280 ++ .../metal_cpp/Metal/MTL4RenderPipeline.hpp | 587 +++ .../MTL4SpecializedFunctionDescriptor.hpp | 100 + .../Metal/MTL4StitchedFunctionDescriptor.hpp | 86 + .../Metal/MTL4TileRenderPipeline.hpp | 173 + .../Metal/MTLAccelerationStructure.hpp | 1887 +++++++++ ...MTLAccelerationStructureCommandEncoder.hpp | 260 ++ .../Metal/MTLAccelerationStructureTypes.hpp | 292 ++ .../include/metal_cpp/Metal/MTLAllocation.hpp | 40 + dist/include/metal_cpp/Metal/MTLArgument.hpp | 787 ++++ .../metal_cpp/Metal/MTLArgumentEncoder.hpp | 235 ++ .../metal_cpp/Metal/MTLBinaryArchive.hpp | 152 + .../metal_cpp/Metal/MTLBlitCommandEncoder.hpp | 226 ++ dist/include/metal_cpp/Metal/MTLBlitPass.hpp | 154 + dist/include/metal_cpp/Metal/MTLBuffer.hpp | 119 + .../metal_cpp/Metal/MTLCaptureManager.hpp | 217 + .../metal_cpp/Metal/MTLCaptureScope.hpp | 91 + 
.../metal_cpp/Metal/MTLCommandBuffer.hpp | 464 +++ .../metal_cpp/Metal/MTLCommandEncoder.hpp | 117 + .../metal_cpp/Metal/MTLCommandQueue.hpp | 158 + .../Metal/MTLComputeCommandEncoder.hpp | 324 ++ .../metal_cpp/Metal/MTLComputePass.hpp | 169 + .../metal_cpp/Metal/MTLComputePipeline.hpp | 439 ++ dist/include/metal_cpp/Metal/MTLCounters.hpp | 243 ++ dist/include/metal_cpp/Metal/MTLDataType.hpp | 129 + dist/include/metal_cpp/Metal/MTLDefines.hpp | 41 + .../metal_cpp/Metal/MTLDepthStencil.hpp | 277 ++ dist/include/metal_cpp/Metal/MTLDevice.hpp | 1493 +++++++ dist/include/metal_cpp/Metal/MTLDrawable.hpp | 90 + .../metal_cpp/Metal/MTLDynamicLibrary.hpp | 78 + dist/include/metal_cpp/Metal/MTLEvent.hpp | 170 + dist/include/metal_cpp/Metal/MTLFence.hpp | 55 + .../Metal/MTLFunctionConstantValues.hpp | 76 + .../metal_cpp/Metal/MTLFunctionDescriptor.hpp | 153 + .../metal_cpp/Metal/MTLFunctionHandle.hpp | 65 + .../metal_cpp/Metal/MTLFunctionLog.hpp | 101 + .../metal_cpp/Metal/MTLFunctionStitching.hpp | 319 ++ .../include/metal_cpp/Metal/MTLGPUAddress.hpp | 36 + .../metal_cpp/Metal/MTLHeaderBridge.hpp | 3120 ++++++++++++++ dist/include/metal_cpp/Metal/MTLHeap.hpp | 318 ++ .../metal_cpp/Metal/MTLIOCommandBuffer.hpp | 182 + .../metal_cpp/Metal/MTLIOCommandQueue.hpp | 211 + .../metal_cpp/Metal/MTLIOCompressor.hpp | 94 + .../Metal/MTLIndirectCommandBuffer.hpp | 376 ++ .../Metal/MTLIndirectCommandEncoder.hpp | 272 ++ .../Metal/MTLIntersectionFunctionTable.hpp | 173 + dist/include/metal_cpp/Metal/MTLLibrary.hpp | 786 ++++ .../metal_cpp/Metal/MTLLinkedFunctions.hpp | 110 + dist/include/metal_cpp/Metal/MTLLogState.hpp | 111 + .../Metal/MTLParallelRenderCommandEncoder.hpp | 83 + dist/include/metal_cpp/Metal/MTLPipeline.hpp | 104 + .../metal_cpp/Metal/MTLPixelFormat.hpp | 173 + dist/include/metal_cpp/Metal/MTLPrivate.hpp | 156 + .../metal_cpp/Metal/MTLRasterizationRate.hpp | 337 ++ .../Metal/MTLRenderCommandEncoder.hpp | 1019 +++++ .../include/metal_cpp/Metal/MTLRenderPass.hpp | 792 ++++ 
.../metal_cpp/Metal/MTLRenderPipeline.hpp | 1876 +++++++++ .../metal_cpp/Metal/MTLResidencySet.hpp | 178 + dist/include/metal_cpp/Metal/MTLResource.hpp | 190 + .../Metal/MTLResourceStateCommandEncoder.hpp | 98 + .../metal_cpp/Metal/MTLResourceStatePass.hpp | 154 + .../metal_cpp/Metal/MTLResourceViewPool.hpp | 118 + dist/include/metal_cpp/Metal/MTLSampler.hpp | 345 ++ .../Metal/MTLStageInputOutputDescriptor.hpp | 356 ++ dist/include/metal_cpp/Metal/MTLTensor.hpp | 297 ++ dist/include/metal_cpp/Metal/MTLTexture.hpp | 803 ++++ .../metal_cpp/Metal/MTLTextureViewPool.hpp | 59 + dist/include/metal_cpp/Metal/MTLTypes.hpp | 164 + dist/include/metal_cpp/Metal/MTLVersion.hpp | 32 + .../metal_cpp/Metal/MTLVertexDescriptor.hpp | 326 ++ .../Metal/MTLVisibleFunctionTable.hpp | 96 + dist/include/metal_cpp/Metal/Metal.hpp | 120 + .../MetalFX/MTL4FXFrameInterpolator.hpp | 47 + .../metal_cpp/MetalFX/MTL4FXSpatialScaler.hpp | 49 + .../MetalFX/MTL4FXTemporalDenoisedScaler.hpp | 49 + .../MetalFX/MTL4FXTemporalScaler.hpp | 49 + .../metal_cpp/MetalFX/MTLFXDefines.hpp | 41 + .../MetalFX/MTLFXFrameInterpolator.hpp | 719 ++++ .../metal_cpp/MetalFX/MTLFXPrivate.hpp | 482 +++ .../metal_cpp/MetalFX/MTLFXSpatialScaler.hpp | 397 ++ .../MetalFX/MTLFXTemporalDenoisedScaler.hpp | 1208 ++++++ .../metal_cpp/MetalFX/MTLFXTemporalScaler.hpp | 803 ++++ dist/include/metal_cpp/MetalFX/MetalFX.hpp | 35 + .../metal_cpp/QuartzCore/CADefines.hpp | 41 + .../metal_cpp/QuartzCore/CAMetalDrawable.hpp | 57 + .../metal_cpp/QuartzCore/CAMetalLayer.hpp | 216 + .../metal_cpp/QuartzCore/CAPrivate.hpp | 150 + .../metal_cpp/QuartzCore/QuartzCore.hpp | 28 + dist/include/metal_cpp/README.md | 313 ++ .../SingleHeader/MakeSingleHeader.py | 271 ++ dist/include/mlx/3rdparty/pocketfft.h | 3581 +++++++++++++++++ dist/include/mlx/allocator.h | 73 + dist/include/mlx/array.h | 645 +++ dist/include/mlx/backend/common/binary.h | 97 + .../include/mlx/backend/common/broadcasting.h | 11 + .../include/mlx/backend/common/buffer_cache.h | 
157 + dist/include/mlx/backend/common/compiled.h | 77 + dist/include/mlx/backend/common/copy.h | 50 + dist/include/mlx/backend/common/hadamard.h | 109 + dist/include/mlx/backend/common/matmul.h | 67 + dist/include/mlx/backend/common/reduce.h | 59 + dist/include/mlx/backend/common/slicing.h | 20 + dist/include/mlx/backend/common/ternary.h | 85 + dist/include/mlx/backend/common/unary.h | 29 + dist/include/mlx/backend/common/utils.h | 205 + dist/include/mlx/backend/cpu/arange.h | 28 + dist/include/mlx/backend/cpu/available.h | 9 + dist/include/mlx/backend/cpu/binary.h | 517 +++ dist/include/mlx/backend/cpu/binary_ops.h | 98 + dist/include/mlx/backend/cpu/binary_two.h | 166 + .../mlx/backend/cpu/compiled_preamble.h | 12 + dist/include/mlx/backend/cpu/copy.h | 36 + dist/include/mlx/backend/cpu/encoder.h | 67 + dist/include/mlx/backend/cpu/eval.h | 12 + dist/include/mlx/backend/cpu/gemm.h | 26 + .../include/mlx/backend/cpu/gemms/simd_gemm.h | 139 + dist/include/mlx/backend/cpu/jit_compiler.h | 20 + dist/include/mlx/backend/cpu/lapack.h | 80 + .../backend/cpu/simd/accelerate_fp16_simd.h | 56 + .../mlx/backend/cpu/simd/accelerate_simd.h | 329 ++ dist/include/mlx/backend/cpu/simd/base_simd.h | 295 ++ dist/include/mlx/backend/cpu/simd/math.h | 193 + .../mlx/backend/cpu/simd/neon_fp16_simd.h | 212 + dist/include/mlx/backend/cpu/simd/simd.h | 4 + dist/include/mlx/backend/cpu/simd/type.h | 11 + dist/include/mlx/backend/cpu/slicing.h | 21 + dist/include/mlx/backend/cpu/ternary.h | 154 + dist/include/mlx/backend/cpu/threefry.h | 21 + dist/include/mlx/backend/cpu/unary.h | 281 ++ dist/include/mlx/backend/cpu/unary_ops.h | 180 + dist/include/mlx/backend/cuda/allocator.h | 89 + dist/include/mlx/backend/cuda/conv/conv.h | 126 + dist/include/mlx/backend/cuda/cublas_utils.h | 96 + dist/include/mlx/backend/cuda/cuda.h | 10 + dist/include/mlx/backend/cuda/cuda_utils.h | 89 + dist/include/mlx/backend/cuda/cudnn_utils.h | 171 + dist/include/mlx/backend/cuda/device.h | 189 + 
dist/include/mlx/backend/cuda/device/config.h | 12 + dist/include/mlx/backend/cuda/event.h | 78 + .../mlx/backend/cuda/gemms/cublas_gemm.h | 114 + dist/include/mlx/backend/cuda/gemms/gemv.h | 24 + dist/include/mlx/backend/cuda/jit_module.h | 119 + dist/include/mlx/backend/cuda/lru_cache.h | 189 + .../mlx/backend/cuda/quantized/cublas_qqmm.h | 88 + .../mlx/backend/cuda/quantized/cuda_fp4.h | 83 + .../mlx/backend/cuda/quantized/qqmm_utils.h | 30 + .../mlx/backend/cuda/quantized/quantized.h | 45 + dist/include/mlx/backend/cuda/utils.h | 46 + dist/include/mlx/backend/cuda/worker.h | 55 + dist/include/mlx/backend/gpu/available.h | 9 + dist/include/mlx/backend/gpu/copy.h | 57 + dist/include/mlx/backend/gpu/eval.h | 18 + dist/include/mlx/backend/gpu/slicing.h | 36 + dist/include/mlx/backend/metal/allocator.h | 79 + dist/include/mlx/backend/metal/binary.h | 33 + dist/include/mlx/backend/metal/device.h | 283 ++ dist/include/mlx/backend/metal/jit/includes.h | 57 + dist/include/mlx/backend/metal/jit/indexing.h | 76 + .../mlx/backend/metal/kernels/arange.h | 9 + .../mlx/backend/metal/kernels/atomic.h | 345 ++ dist/include/mlx/backend/metal/kernels/bf16.h | 16 + .../mlx/backend/metal/kernels/bf16_math.h | 380 ++ .../mlx/backend/metal/kernels/binary.h | 199 + .../mlx/backend/metal/kernels/binary_ops.h | 326 ++ .../mlx/backend/metal/kernels/binary_two.h | 244 ++ .../include/mlx/backend/metal/kernels/cexpf.h | 134 + .../mlx/backend/metal/kernels/complex.h | 173 + dist/include/mlx/backend/metal/kernels/copy.h | 276 ++ .../mlx/backend/metal/kernels/defines.h | 24 + dist/include/mlx/backend/metal/kernels/erf.h | 69 + .../mlx/backend/metal/kernels/expm1f.h | 90 + dist/include/mlx/backend/metal/kernels/fft.h | 486 +++ .../mlx/backend/metal/kernels/fft/radix.h | 328 ++ .../mlx/backend/metal/kernels/fft/readwrite.h | 624 +++ dist/include/mlx/backend/metal/kernels/fp4.h | 59 + dist/include/mlx/backend/metal/kernels/fp8.h | 82 + .../mlx/backend/metal/kernels/fp_quantized.h | 1804 +++++++++ 
.../backend/metal/kernels/fp_quantized_nax.h | 1059 +++++ .../mlx/backend/metal/kernels/gemv_masked.h | 827 ++++ .../mlx/backend/metal/kernels/hadamard.h | 182 + .../backend/metal/kernels/indexing/gather.h | 51 + .../metal/kernels/indexing/gather_axis.h | 44 + .../metal/kernels/indexing/gather_front.h | 24 + .../backend/metal/kernels/indexing/indexing.h | 23 + .../metal/kernels/indexing/masked_scatter.h | 38 + .../backend/metal/kernels/indexing/scatter.h | 59 + .../metal/kernels/indexing/scatter_axis.h | 52 + .../mlx/backend/metal/kernels/logsumexp.h | 140 + .../mlx/backend/metal/kernels/quantized.h | 2502 ++++++++++++ .../mlx/backend/metal/kernels/quantized_nax.h | 1705 ++++++++ .../backend/metal/kernels/quantized_utils.h | 90 + .../mlx/backend/metal/kernels/reduce.h | 5 + .../mlx/backend/metal/kernels/reduce_utils.h | 6 + .../mlx/backend/metal/kernels/reduction/ops.h | 275 ++ .../metal/kernels/reduction/reduce_all.h | 66 + .../metal/kernels/reduction/reduce_col.h | 398 ++ .../metal/kernels/reduction/reduce_init.h | 8 + .../metal/kernels/reduction/reduce_row.h | 369 ++ dist/include/mlx/backend/metal/kernels/scan.h | 514 +++ .../mlx/backend/metal/kernels/sdpa_vector.h | 415 ++ .../mlx/backend/metal/kernels/softmax.h | 190 + dist/include/mlx/backend/metal/kernels/sort.h | 715 ++++ .../backend/metal/kernels/steel/attn/attn.h | 296 ++ .../steel/attn/kernels/steel_attention.h | 476 +++ .../steel/attn/kernels/steel_attention_nax.h | 481 +++ .../backend/metal/kernels/steel/attn/loader.h | 264 ++ .../backend/metal/kernels/steel/attn/mma.h | 750 ++++ .../backend/metal/kernels/steel/attn/nax.h | 1076 +++++ .../backend/metal/kernels/steel/attn/params.h | 44 + .../metal/kernels/steel/attn/transforms.h | 71 + .../backend/metal/kernels/steel/conv/conv.h | 13 + .../kernels/steel/conv/kernels/steel_conv.h | 176 + .../steel/conv/kernels/steel_conv_general.h | 225 ++ .../backend/metal/kernels/steel/conv/loader.h | 6 + .../steel/conv/loaders/loader_channel_l.h | 451 +++ 
.../steel/conv/loaders/loader_channel_n.h | 319 ++ .../steel/conv/loaders/loader_general.h | 381 ++ .../backend/metal/kernels/steel/conv/params.h | 62 + .../mlx/backend/metal/kernels/steel/defines.h | 7 + .../backend/metal/kernels/steel/gemm/gemm.h | 295 ++ .../metal/kernels/steel/gemm/gemm_nax.h | 156 + .../steel/gemm/kernels/steel_gemm_fused.h | 346 ++ .../steel/gemm/kernels/steel_gemm_fused_nax.h | 207 + .../steel/gemm/kernels/steel_gemm_gather.h | 459 +++ .../gemm/kernels/steel_gemm_gather_nax.h | 132 + .../steel/gemm/kernels/steel_gemm_masked.h | 719 ++++ .../steel/gemm/kernels/steel_gemm_segmented.h | 266 ++ .../steel/gemm/kernels/steel_gemm_splitk.h | 227 ++ .../backend/metal/kernels/steel/gemm/loader.h | 137 + .../backend/metal/kernels/steel/gemm/mma.h | 1146 ++++++ .../backend/metal/kernels/steel/gemm/nax.h | 1084 +++++ .../backend/metal/kernels/steel/gemm/params.h | 64 + .../metal/kernels/steel/gemm/transforms.h | 72 + .../mlx/backend/metal/kernels/steel/utils.h | 42 + .../kernels/steel/utils/integral_constant.h | 134 + .../metal/kernels/steel/utils/type_traits.h | 55 + .../mlx/backend/metal/kernels/ternary.h | 145 + .../mlx/backend/metal/kernels/ternary_ops.h | 10 + .../include/mlx/backend/metal/kernels/unary.h | 63 + .../mlx/backend/metal/kernels/unary_ops.h | 454 +++ .../include/mlx/backend/metal/kernels/utils.h | 444 ++ dist/include/mlx/backend/metal/matmul.h | 144 + dist/include/mlx/backend/metal/metal.h | 22 + dist/include/mlx/backend/metal/reduce.h | 41 + dist/include/mlx/backend/metal/resident.h | 32 + dist/include/mlx/backend/metal/scan.h | 17 + dist/include/mlx/backend/metal/ternary.h | 21 + dist/include/mlx/backend/metal/unary.h | 21 + dist/include/mlx/backend/metal/utils.h | 84 + .../include/mlx/backend/no_gpu/apple_memory.h | 16 + .../include/mlx/backend/no_gpu/linux_memory.h | 22 + dist/include/mlx/c/array.h | 379 ++ dist/include/mlx/c/closure.h | 197 + dist/include/mlx/c/compile.h | 55 + dist/include/mlx/c/device.h | 80 + 
dist/include/mlx/c/distributed.h | 81 + dist/include/mlx/c/distributed_group.h | 58 + dist/include/mlx/c/error.h | 41 + dist/include/mlx/c/export.h | 75 + dist/include/mlx/c/fast.h | 205 + dist/include/mlx/c/fft.h | 136 + dist/include/mlx/c/half.h | 26 + dist/include/mlx/c/io.h | 61 + dist/include/mlx/c/io_types.h | 104 + dist/include/mlx/c/linalg.h | 126 + dist/include/mlx/c/map.h | 149 + dist/include/mlx/c/memory.h | 45 + dist/include/mlx/c/metal.h | 48 + dist/include/mlx/c/mlx.h | 33 + dist/include/mlx/c/ops.h | 1233 ++++++ dist/include/mlx/c/optional.h | 51 + dist/include/mlx/c/random.h | 164 + dist/include/mlx/c/stream.h | 88 + dist/include/mlx/c/string.h | 55 + dist/include/mlx/c/transforms.h | 66 + dist/include/mlx/c/transforms_impl.h | 52 + dist/include/mlx/c/vector.h | 133 + dist/include/mlx/c/version.h | 18 + dist/include/mlx/compile.h | 44 + dist/include/mlx/compile_impl.h | 69 + dist/include/mlx/device.h | 31 + dist/include/mlx/distributed/distributed.h | 60 + .../mlx/distributed/distributed_impl.h | 59 + dist/include/mlx/distributed/jaccl/jaccl.h | 12 + dist/include/mlx/distributed/mpi/mpi.h | 12 + .../mlx/distributed/mpi/mpi_declarations.h | 28 + dist/include/mlx/distributed/nccl/nccl.h | 12 + dist/include/mlx/distributed/ops.h | 56 + dist/include/mlx/distributed/primitives.h | 156 + dist/include/mlx/distributed/reduction_ops.h | 38 + dist/include/mlx/distributed/ring/ring.h | 12 + dist/include/mlx/distributed/utils.h | 67 + dist/include/mlx/dtype.h | 115 + dist/include/mlx/dtype_utils.h | 119 + dist/include/mlx/einsum.h | 22 + dist/include/mlx/event.h | 58 + dist/include/mlx/export.h | 136 + dist/include/mlx/export_impl.h | 98 + dist/include/mlx/fast.h | 102 + dist/include/mlx/fast_primitives.h | 427 ++ dist/include/mlx/fence.h | 39 + dist/include/mlx/fft.h | 167 + dist/include/mlx/graph_utils.h | 66 + dist/include/mlx/io.h | 61 + dist/include/mlx/io/gguf.h | 20 + dist/include/mlx/io/load.h | 175 + dist/include/mlx/linalg.h | 111 + 
dist/include/mlx/memory.h | 78 + dist/include/mlx/mlx.h | 25 + dist/include/mlx/ops.h | 1627 ++++++++ dist/include/mlx/primitives.h | 2524 ++++++++++++ dist/include/mlx/random.h | 282 ++ dist/include/mlx/scheduler.h | 188 + dist/include/mlx/small_vector.h | 540 +++ dist/include/mlx/stream.h | 41 + dist/include/mlx/threadpool.h | 133 + dist/include/mlx/transforms.h | 229 ++ dist/include/mlx/transforms_impl.h | 86 + dist/include/mlx/types/bf16.h | 187 + dist/include/mlx/types/complex.h | 113 + dist/include/mlx/types/fp16.h | 234 ++ dist/include/mlx/types/half_types.h | 58 + dist/include/mlx/types/limits.h | 70 + dist/include/mlx/utils.h | 175 + dist/include/mlx/version.h | 20 + 375 files changed, 89633 insertions(+), 2 deletions(-) create mode 100644 dist/include/metal_cpp/Foundation/Foundation.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSArray.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSAutoreleasePool.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSBundle.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSData.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSDate.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSDefines.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSDictionary.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSEnumerator.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSError.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSLock.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSNotification.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSNumber.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSObjCRuntime.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSObject.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSPrivate.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSProcessInfo.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSRange.hpp create mode 100644 
dist/include/metal_cpp/Foundation/NSSet.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSSharedPtr.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSString.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSTypes.hpp create mode 100644 dist/include/metal_cpp/Foundation/NSURL.hpp create mode 100644 dist/include/metal_cpp/LICENSE.txt create mode 100644 dist/include/metal_cpp/Metal/MTL4AccelerationStructure.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4Archive.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4ArgumentTable.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4BinaryFunction.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4BinaryFunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CommandAllocator.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CommandBuffer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CommandQueue.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CommitFeedback.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4Compiler.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4CompilerTask.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4ComputeCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4ComputePipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4Counters.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4FunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4LibraryDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4LibraryFunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4LinkingDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4MachineLearningCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4MachineLearningPipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4MeshRenderPipeline.hpp create mode 100644 
dist/include/metal_cpp/Metal/MTL4PipelineDataSetSerializer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4PipelineState.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4RenderCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4RenderPass.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4RenderPipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4StitchedFunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTL4TileRenderPipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLAccelerationStructure.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLAllocation.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLArgument.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLArgumentEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLBinaryArchive.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLBlitPass.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLBuffer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCaptureManager.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCaptureScope.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCommandBuffer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCommandQueue.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLComputePass.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLComputePipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLCounters.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLDataType.hpp 
create mode 100644 dist/include/metal_cpp/Metal/MTLDefines.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLDepthStencil.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLDevice.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLDrawable.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLDynamicLibrary.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLEvent.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFence.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFunctionHandle.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFunctionLog.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLFunctionStitching.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLGPUAddress.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLHeaderBridge.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLHeap.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIOCommandQueue.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIOCompressor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLLibrary.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLLinkedFunctions.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLLogState.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLPipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLPixelFormat.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLPrivate.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLRasterizationRate.hpp 
create mode 100644 dist/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLRenderPass.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLRenderPipeline.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLResidencySet.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLResource.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLResourceStatePass.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLResourceViewPool.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLSampler.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLTensor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLTexture.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLTextureViewPool.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLTypes.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLVersion.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLVertexDescriptor.hpp create mode 100644 dist/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp create mode 100644 dist/include/metal_cpp/Metal/Metal.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTL4FXFrameInterpolator.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTL4FXSpatialScaler.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTL4FXTemporalScaler.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXDefines.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXFrameInterpolator.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXPrivate.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp create mode 100644 dist/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp create 
mode 100644 dist/include/metal_cpp/MetalFX/MetalFX.hpp create mode 100644 dist/include/metal_cpp/QuartzCore/CADefines.hpp create mode 100644 dist/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp create mode 100644 dist/include/metal_cpp/QuartzCore/CAMetalLayer.hpp create mode 100644 dist/include/metal_cpp/QuartzCore/CAPrivate.hpp create mode 100644 dist/include/metal_cpp/QuartzCore/QuartzCore.hpp create mode 100644 dist/include/metal_cpp/README.md create mode 100644 dist/include/metal_cpp/SingleHeader/MakeSingleHeader.py create mode 100644 dist/include/mlx/3rdparty/pocketfft.h create mode 100644 dist/include/mlx/allocator.h create mode 100644 dist/include/mlx/array.h create mode 100644 dist/include/mlx/backend/common/binary.h create mode 100644 dist/include/mlx/backend/common/broadcasting.h create mode 100644 dist/include/mlx/backend/common/buffer_cache.h create mode 100644 dist/include/mlx/backend/common/compiled.h create mode 100644 dist/include/mlx/backend/common/copy.h create mode 100644 dist/include/mlx/backend/common/hadamard.h create mode 100644 dist/include/mlx/backend/common/matmul.h create mode 100644 dist/include/mlx/backend/common/reduce.h create mode 100644 dist/include/mlx/backend/common/slicing.h create mode 100644 dist/include/mlx/backend/common/ternary.h create mode 100644 dist/include/mlx/backend/common/unary.h create mode 100644 dist/include/mlx/backend/common/utils.h create mode 100644 dist/include/mlx/backend/cpu/arange.h create mode 100644 dist/include/mlx/backend/cpu/available.h create mode 100644 dist/include/mlx/backend/cpu/binary.h create mode 100644 dist/include/mlx/backend/cpu/binary_ops.h create mode 100644 dist/include/mlx/backend/cpu/binary_two.h create mode 100644 dist/include/mlx/backend/cpu/compiled_preamble.h create mode 100644 dist/include/mlx/backend/cpu/copy.h create mode 100644 dist/include/mlx/backend/cpu/encoder.h create mode 100644 dist/include/mlx/backend/cpu/eval.h create mode 100644 dist/include/mlx/backend/cpu/gemm.h 
create mode 100644 dist/include/mlx/backend/cpu/gemms/simd_gemm.h create mode 100644 dist/include/mlx/backend/cpu/jit_compiler.h create mode 100644 dist/include/mlx/backend/cpu/lapack.h create mode 100644 dist/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h create mode 100644 dist/include/mlx/backend/cpu/simd/accelerate_simd.h create mode 100644 dist/include/mlx/backend/cpu/simd/base_simd.h create mode 100644 dist/include/mlx/backend/cpu/simd/math.h create mode 100644 dist/include/mlx/backend/cpu/simd/neon_fp16_simd.h create mode 100644 dist/include/mlx/backend/cpu/simd/simd.h create mode 100644 dist/include/mlx/backend/cpu/simd/type.h create mode 100644 dist/include/mlx/backend/cpu/slicing.h create mode 100644 dist/include/mlx/backend/cpu/ternary.h create mode 100644 dist/include/mlx/backend/cpu/threefry.h create mode 100644 dist/include/mlx/backend/cpu/unary.h create mode 100644 dist/include/mlx/backend/cpu/unary_ops.h create mode 100644 dist/include/mlx/backend/cuda/allocator.h create mode 100644 dist/include/mlx/backend/cuda/conv/conv.h create mode 100644 dist/include/mlx/backend/cuda/cublas_utils.h create mode 100644 dist/include/mlx/backend/cuda/cuda.h create mode 100644 dist/include/mlx/backend/cuda/cuda_utils.h create mode 100644 dist/include/mlx/backend/cuda/cudnn_utils.h create mode 100644 dist/include/mlx/backend/cuda/device.h create mode 100644 dist/include/mlx/backend/cuda/device/config.h create mode 100644 dist/include/mlx/backend/cuda/event.h create mode 100644 dist/include/mlx/backend/cuda/gemms/cublas_gemm.h create mode 100644 dist/include/mlx/backend/cuda/gemms/gemv.h create mode 100644 dist/include/mlx/backend/cuda/jit_module.h create mode 100644 dist/include/mlx/backend/cuda/lru_cache.h create mode 100644 dist/include/mlx/backend/cuda/quantized/cublas_qqmm.h create mode 100644 dist/include/mlx/backend/cuda/quantized/cuda_fp4.h create mode 100644 dist/include/mlx/backend/cuda/quantized/qqmm_utils.h create mode 100644 
dist/include/mlx/backend/cuda/quantized/quantized.h create mode 100644 dist/include/mlx/backend/cuda/utils.h create mode 100644 dist/include/mlx/backend/cuda/worker.h create mode 100644 dist/include/mlx/backend/gpu/available.h create mode 100644 dist/include/mlx/backend/gpu/copy.h create mode 100644 dist/include/mlx/backend/gpu/eval.h create mode 100644 dist/include/mlx/backend/gpu/slicing.h create mode 100644 dist/include/mlx/backend/metal/allocator.h create mode 100644 dist/include/mlx/backend/metal/binary.h create mode 100644 dist/include/mlx/backend/metal/device.h create mode 100644 dist/include/mlx/backend/metal/jit/includes.h create mode 100644 dist/include/mlx/backend/metal/jit/indexing.h create mode 100644 dist/include/mlx/backend/metal/kernels/arange.h create mode 100644 dist/include/mlx/backend/metal/kernels/atomic.h create mode 100644 dist/include/mlx/backend/metal/kernels/bf16.h create mode 100644 dist/include/mlx/backend/metal/kernels/bf16_math.h create mode 100644 dist/include/mlx/backend/metal/kernels/binary.h create mode 100644 dist/include/mlx/backend/metal/kernels/binary_ops.h create mode 100644 dist/include/mlx/backend/metal/kernels/binary_two.h create mode 100644 dist/include/mlx/backend/metal/kernels/cexpf.h create mode 100644 dist/include/mlx/backend/metal/kernels/complex.h create mode 100644 dist/include/mlx/backend/metal/kernels/copy.h create mode 100644 dist/include/mlx/backend/metal/kernels/defines.h create mode 100644 dist/include/mlx/backend/metal/kernels/erf.h create mode 100644 dist/include/mlx/backend/metal/kernels/expm1f.h create mode 100644 dist/include/mlx/backend/metal/kernels/fft.h create mode 100644 dist/include/mlx/backend/metal/kernels/fft/radix.h create mode 100644 dist/include/mlx/backend/metal/kernels/fft/readwrite.h create mode 100644 dist/include/mlx/backend/metal/kernels/fp4.h create mode 100644 dist/include/mlx/backend/metal/kernels/fp8.h create mode 100644 dist/include/mlx/backend/metal/kernels/fp_quantized.h create 
mode 100644 dist/include/mlx/backend/metal/kernels/fp_quantized_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/gemv_masked.h create mode 100644 dist/include/mlx/backend/metal/kernels/hadamard.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/gather.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/gather_axis.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/gather_front.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/indexing.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/masked_scatter.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/scatter.h create mode 100644 dist/include/mlx/backend/metal/kernels/indexing/scatter_axis.h create mode 100644 dist/include/mlx/backend/metal/kernels/logsumexp.h create mode 100644 dist/include/mlx/backend/metal/kernels/quantized.h create mode 100644 dist/include/mlx/backend/metal/kernels/quantized_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/quantized_utils.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduce.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduce_utils.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduction/ops.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduction/reduce_all.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduction/reduce_col.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduction/reduce_init.h create mode 100644 dist/include/mlx/backend/metal/kernels/reduction/reduce_row.h create mode 100644 dist/include/mlx/backend/metal/kernels/scan.h create mode 100644 dist/include/mlx/backend/metal/kernels/sdpa_vector.h create mode 100644 dist/include/mlx/backend/metal/kernels/softmax.h create mode 100644 dist/include/mlx/backend/metal/kernels/sort.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/attn.h create mode 100644 
dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/loader.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/mma.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/params.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/attn/transforms.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/conv.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/loader.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/conv/params.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/defines.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/gemm.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h create 
mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/loader.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/mma.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/nax.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/params.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/gemm/transforms.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/utils.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h create mode 100644 dist/include/mlx/backend/metal/kernels/steel/utils/type_traits.h create mode 100644 dist/include/mlx/backend/metal/kernels/ternary.h create mode 100644 dist/include/mlx/backend/metal/kernels/ternary_ops.h create mode 100644 dist/include/mlx/backend/metal/kernels/unary.h create mode 100644 dist/include/mlx/backend/metal/kernels/unary_ops.h create mode 100644 dist/include/mlx/backend/metal/kernels/utils.h create mode 100644 dist/include/mlx/backend/metal/matmul.h create mode 100644 dist/include/mlx/backend/metal/metal.h create mode 100644 dist/include/mlx/backend/metal/reduce.h create mode 100644 dist/include/mlx/backend/metal/resident.h create mode 100644 dist/include/mlx/backend/metal/scan.h create mode 100644 dist/include/mlx/backend/metal/ternary.h create mode 100644 dist/include/mlx/backend/metal/unary.h create mode 100644 dist/include/mlx/backend/metal/utils.h create mode 100644 dist/include/mlx/backend/no_gpu/apple_memory.h create mode 100644 dist/include/mlx/backend/no_gpu/linux_memory.h create mode 100644 dist/include/mlx/c/array.h create mode 100644 dist/include/mlx/c/closure.h create mode 100644 dist/include/mlx/c/compile.h create mode 100644 dist/include/mlx/c/device.h create mode 100644 
dist/include/mlx/c/distributed.h create mode 100644 dist/include/mlx/c/distributed_group.h create mode 100644 dist/include/mlx/c/error.h create mode 100644 dist/include/mlx/c/export.h create mode 100644 dist/include/mlx/c/fast.h create mode 100644 dist/include/mlx/c/fft.h create mode 100644 dist/include/mlx/c/half.h create mode 100644 dist/include/mlx/c/io.h create mode 100644 dist/include/mlx/c/io_types.h create mode 100644 dist/include/mlx/c/linalg.h create mode 100644 dist/include/mlx/c/map.h create mode 100644 dist/include/mlx/c/memory.h create mode 100644 dist/include/mlx/c/metal.h create mode 100644 dist/include/mlx/c/mlx.h create mode 100644 dist/include/mlx/c/ops.h create mode 100644 dist/include/mlx/c/optional.h create mode 100644 dist/include/mlx/c/random.h create mode 100644 dist/include/mlx/c/stream.h create mode 100644 dist/include/mlx/c/string.h create mode 100644 dist/include/mlx/c/transforms.h create mode 100644 dist/include/mlx/c/transforms_impl.h create mode 100644 dist/include/mlx/c/vector.h create mode 100644 dist/include/mlx/c/version.h create mode 100644 dist/include/mlx/compile.h create mode 100644 dist/include/mlx/compile_impl.h create mode 100644 dist/include/mlx/device.h create mode 100644 dist/include/mlx/distributed/distributed.h create mode 100644 dist/include/mlx/distributed/distributed_impl.h create mode 100644 dist/include/mlx/distributed/jaccl/jaccl.h create mode 100644 dist/include/mlx/distributed/mpi/mpi.h create mode 100644 dist/include/mlx/distributed/mpi/mpi_declarations.h create mode 100644 dist/include/mlx/distributed/nccl/nccl.h create mode 100644 dist/include/mlx/distributed/ops.h create mode 100644 dist/include/mlx/distributed/primitives.h create mode 100644 dist/include/mlx/distributed/reduction_ops.h create mode 100644 dist/include/mlx/distributed/ring/ring.h create mode 100644 dist/include/mlx/distributed/utils.h create mode 100644 dist/include/mlx/dtype.h create mode 100644 dist/include/mlx/dtype_utils.h create mode 
100644 dist/include/mlx/einsum.h create mode 100644 dist/include/mlx/event.h create mode 100644 dist/include/mlx/export.h create mode 100644 dist/include/mlx/export_impl.h create mode 100644 dist/include/mlx/fast.h create mode 100644 dist/include/mlx/fast_primitives.h create mode 100644 dist/include/mlx/fence.h create mode 100644 dist/include/mlx/fft.h create mode 100644 dist/include/mlx/graph_utils.h create mode 100644 dist/include/mlx/io.h create mode 100644 dist/include/mlx/io/gguf.h create mode 100644 dist/include/mlx/io/load.h create mode 100644 dist/include/mlx/linalg.h create mode 100644 dist/include/mlx/memory.h create mode 100644 dist/include/mlx/mlx.h create mode 100644 dist/include/mlx/ops.h create mode 100644 dist/include/mlx/primitives.h create mode 100644 dist/include/mlx/random.h create mode 100644 dist/include/mlx/scheduler.h create mode 100644 dist/include/mlx/small_vector.h create mode 100644 dist/include/mlx/stream.h create mode 100644 dist/include/mlx/threadpool.h create mode 100644 dist/include/mlx/transforms.h create mode 100644 dist/include/mlx/transforms_impl.h create mode 100644 dist/include/mlx/types/bf16.h create mode 100644 dist/include/mlx/types/complex.h create mode 100644 dist/include/mlx/types/fp16.h create mode 100644 dist/include/mlx/types/half_types.h create mode 100644 dist/include/mlx/types/limits.h create mode 100644 dist/include/mlx/utils.h create mode 100644 dist/include/mlx/version.h diff --git a/.gitignore b/.gitignore index 6d1fd5a..a86575f 100644 --- a/.gitignore +++ b/.gitignore @@ -10,8 +10,9 @@ CMakeFiles/ cmake_install.cmake Makefile -# CMake install output -dist/ +# CMake install output (keep headers for Go module consumers) +dist/* +!dist/include/ # IDE .idea/ diff --git a/dist/include/metal_cpp/Foundation/Foundation.hpp b/dist/include/metal_cpp/Foundation/Foundation.hpp new file mode 100644 index 0000000..31e8fb3 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/Foundation.hpp @@ -0,0 +1,47 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/Foundation.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSArray.hpp" +#include "NSAutoreleasePool.hpp" +#include "NSBundle.hpp" +#include "NSData.hpp" +#include "NSDate.hpp" +#include "NSDefines.hpp" +#include "NSDictionary.hpp" +#include "NSEnumerator.hpp" +#include "NSError.hpp" +#include "NSLock.hpp" +#include "NSNotification.hpp" +#include "NSNumber.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSProcessInfo.hpp" +#include "NSRange.hpp" +#include "NSSet.hpp" +#include "NSSharedPtr.hpp" +#include "NSString.hpp" +#include "NSTypes.hpp" +#include "NSURL.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSArray.hpp b/dist/include/metal_cpp/Foundation/NSArray.hpp new file mode 100644 index 
0000000..ea04d1e --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSArray.hpp @@ -0,0 +1,124 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSArray.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" +#include "NSEnumerator.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Array : public Copying +{ +public: + static Array* array(); + static Array* array(const Object* pObject); + static Array* array(const Object* const* pObjects, UInteger count); + + static Array* alloc(); + + Array* init(); + Array* init(const Object* const* pObjects, UInteger count); + Array* init(const class Coder* pCoder); + + template + _Object* object(UInteger index) const; + UInteger count() const; + Enumerator* objectEnumerator() const; +}; +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(array)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array(const Object* pObject) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(arrayWithObject_), pObject); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::array(const Object* const* pObjects, UInteger count) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSArray), _NS_PRIVATE_SEL(arrayWithObjects_count_), pObjects, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSArray)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init(const Object* const* pObjects, UInteger count) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_count_), pObjects, count); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Array::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Array::count() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Object* NS::Array::object(UInteger index) const +{ + return Object::sendMessage<_Object*>(this, _NS_PRIVATE_SEL(objectAtIndex_), index); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Enumerator* NS::Array::objectEnumerator() const +{ + return NS::Object::sendMessage*>(this, _NS_PRIVATE_SEL(objectEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSAutoreleasePool.hpp b/dist/include/metal_cpp/Foundation/NSAutoreleasePool.hpp new file mode 100644 index 0000000..6d01a46 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSAutoreleasePool.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSAutoreleasePool.hpp +// +// Copyright 2020-2024 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class AutoreleasePool : public Object +{ +public: + static AutoreleasePool* alloc(); + AutoreleasePool* init(); + + void drain(); + + void addObject(Object* pObject); + + static void showPools(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::AutoreleasePool* NS::AutoreleasePool::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSAutoreleasePool)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::AutoreleasePool* NS::AutoreleasePool::init() +{ + return NS::Object::init(); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::drain() +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(drain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::addObject(Object* pObject) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(addObject_), pObject); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::AutoreleasePool::showPools() +{ + Object::sendMessage(_NS_PRIVATE_CLS(NSAutoreleasePool), _NS_PRIVATE_SEL(showPools)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSBundle.hpp b/dist/include/metal_cpp/Foundation/NSBundle.hpp new file mode 100644 index 0000000..b9637f5 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSBundle.hpp @@ -0,0 +1,374 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSBundle.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSNotification.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_CONST(NotificationName, BundleDidLoadNotification); +_NS_CONST(NotificationName, BundleResourceRequestLowDiskSpaceNotification); + +class String* LocalizedString(const String* pKey, const String*); +class String* LocalizedStringFromTable(const String* pKey, const String* pTbl, const String*); +class String* LocalizedStringFromTableInBundle(const String* pKey, const String* pTbl, const class Bundle* pBdle, const String*); +class String* LocalizedStringWithDefaultValue(const String* pKey, const String* pTbl, const class Bundle* pBdle, const String* pVal, const String*); + +class Bundle : public Referencing +{ +public: + static Bundle* mainBundle(); + + static Bundle* bundle(const class String* pPath); + static Bundle* bundle(const class URL* pURL); + + static class Array* allBundles(); + static class Array* allFrameworks(); + + static Bundle* alloc(); + + Bundle* init(const class String* pPath); + Bundle* init(const class URL* pURL); + + bool load(); + bool unload(); + + bool isLoaded() const; + + bool preflightAndReturnError(class Error** pError) const; + bool loadAndReturnError(class Error** pError); + + class URL* bundleURL() const; + class URL* resourceURL() const; + class URL* executableURL() const; + class URL* 
URLForAuxiliaryExecutable(const class String* pExecutableName) const; + + class URL* privateFrameworksURL() const; + class URL* sharedFrameworksURL() const; + class URL* sharedSupportURL() const; + class URL* builtInPlugInsURL() const; + class URL* appStoreReceiptURL() const; + + class String* bundlePath() const; + class String* resourcePath() const; + class String* executablePath() const; + class String* pathForAuxiliaryExecutable(const class String* pExecutableName) const; + + class String* privateFrameworksPath() const; + class String* sharedFrameworksPath() const; + class String* sharedSupportPath() const; + class String* builtInPlugInsPath() const; + + class String* bundleIdentifier() const; + class Dictionary* infoDictionary() const; + class Dictionary* localizedInfoDictionary() const; + class Object* objectForInfoDictionaryKey(const class String* pKey); + + class String* localizedString(const class String* pKey, const class String* pValue = nullptr, const class String* pTableName = nullptr) const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::NotificationName, BundleDidLoadNotification); +_NS_PRIVATE_DEF_CONST(NS::NotificationName, BundleResourceRequestLowDiskSpaceNotification); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedString(const String* pKey, const String*) +{ + return Bundle::mainBundle()->localizedString(pKey, nullptr, nullptr); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringFromTable(const String* pKey, const String* pTbl, const String*) +{ + return 
Bundle::mainBundle()->localizedString(pKey, nullptr, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringFromTableInBundle(const String* pKey, const String* pTbl, const Bundle* pBdl, const String*) +{ + return pBdl->localizedString(pKey, nullptr, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::LocalizedStringWithDefaultValue(const String* pKey, const String* pTbl, const Bundle* pBdl, const String* pVal, const String*) +{ + return pBdl->localizedString(pKey, pVal, pTbl); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::mainBundle() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(mainBundle)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::bundle(const class String* pPath) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(bundleWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::bundle(const class URL* pURL) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(bundleWithURL_), pURL); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Bundle::allBundles() +{ + 
return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(allBundles)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::Bundle::allFrameworks() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(allFrameworks)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::alloc() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSBundle), _NS_PRIVATE_SEL(alloc)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::init(const String* pPath) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Bundle* NS::Bundle::init(const URL* pURL) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithURL_), pURL); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::load() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(load)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::unload() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unload)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::isLoaded() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isLoaded)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::preflightAndReturnError(Error** pError) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(preflightAndReturnError_), pError); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Bundle::loadAndReturnError(Error** pError) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(loadAndReturnError_), pError); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::bundleURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundleURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::resourceURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(resourceURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::executableURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(executableURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* 
NS::Bundle::URLForAuxiliaryExecutable(const String* pExecutableName) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(URLForAuxiliaryExecutable_), pExecutableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::privateFrameworksURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(privateFrameworksURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::sharedFrameworksURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedFrameworksURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::sharedSupportURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedSupportURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::builtInPlugInsURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(builtInPlugInsURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::Bundle::appStoreReceiptURL() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(appStoreReceiptURL)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::bundlePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundlePath)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::resourcePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(resourcePath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::executablePath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(executablePath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::pathForAuxiliaryExecutable(const String* pExecutableName) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(pathForAuxiliaryExecutable_), pExecutableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::privateFrameworksPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(privateFrameworksPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::sharedFrameworksPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedFrameworksPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::sharedSupportPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(sharedSupportPath)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::builtInPlugInsPath() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(builtInPlugInsPath)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::bundleIdentifier() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(bundleIdentifier)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Bundle::infoDictionary() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(infoDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Bundle::localizedInfoDictionary() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedInfoDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::Bundle::objectForInfoDictionaryKey(const String* pKey) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(objectForInfoDictionaryKey_), pKey); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Bundle::localizedString(const String* pKey, const String* pValue /* = nullptr */, const String* pTableName /* = nullptr */) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedStringForKey_value_table_), pKey, 
pValue, pTableName); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSData.hpp b/dist/include/metal_cpp/Foundation/NSData.hpp new file mode 100644 index 0000000..3ad3606 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSData.hpp @@ -0,0 +1,54 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSData.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Data : public Copying +{ +public: + void* mutableBytes() const; + UInteger length() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void* NS::Data::mutableBytes() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(mutableBytes)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Data::length() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(length)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSDate.hpp b/dist/include/metal_cpp/Foundation/NSDate.hpp new file mode 100644 index 0000000..0a5ec7d --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSDate.hpp @@ -0,0 +1,53 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDate.hpp +// +// Copyright 2020-2024 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +using TimeInterval = double; + +class Date : public Copying +{ +public: + static Date* dateWithTimeIntervalSinceNow(TimeInterval secs); +}; + +} // NS + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Date* NS::Date::dateWithTimeIntervalSinceNow(NS::TimeInterval secs) +{ + return NS::Object::sendMessage(_NS_PRIVATE_CLS(NSDate), _NS_PRIVATE_SEL(dateWithTimeIntervalSinceNow_), secs); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/dist/include/metal_cpp/Foundation/NSDefines.hpp 
b/dist/include/metal_cpp/Foundation/NSDefines.hpp new file mode 100644 index 0000000..38bbb56 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSDefines.hpp @@ -0,0 +1,45 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _NS_WEAK_IMPORT __attribute__((weak_import)) +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_EXPORT __attribute__((visibility("hidden"))) +#else +#define _NS_EXPORT __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_EXTERN extern "C" _NS_EXPORT +#define _NS_INLINE inline __attribute__((always_inline)) +#define _NS_PACKED __attribute__((packed)) + +#define _NS_CONST(type, name) _NS_EXTERN type const name +#define _NS_ENUM(type, name) enum name : type +#define _NS_OPTIONS(type, name) \ + using name = type; \ + enum : name + +#define _NS_CAST_TO_UINT(value) static_cast(value) +#define _NS_VALIDATE_SIZE(ns, 
name) static_assert(sizeof(ns::name) == sizeof(ns##name), "size mismatch " #ns "::" #name) +#define _NS_VALIDATE_ENUM(ns, name) static_assert(_NS_CAST_TO_UINT(ns::name) == _NS_CAST_TO_UINT(ns##name), "value mismatch " #ns "::" #name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSDictionary.hpp b/dist/include/metal_cpp/Foundation/NSDictionary.hpp new file mode 100644 index 0000000..d4a1519 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSDictionary.hpp @@ -0,0 +1,128 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSDictionary.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSEnumerator.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Dictionary : public NS::Copying +{ +public: + static Dictionary* dictionary(); + static Dictionary* dictionary(const Object* pObject, const Object* pKey); + static Dictionary* dictionary(const Object* const* pObjects, const Object* const* pKeys, UInteger count); + + static Dictionary* alloc(); + + Dictionary* init(); + Dictionary* init(const Object* const* pObjects, const Object* const* pKeys, UInteger count); + Dictionary* init(const class Coder* pCoder); + + template + Enumerator<_KeyType>* keyEnumerator() const; + + template + _Object* object(const Object* pKey) const; + UInteger count() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary(const Object* pObject, const Object* pKey) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionaryWithObject_forKey_), pObject, pKey); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::dictionary(const Object* const* pObjects, const Object* const* pKeys, UInteger count) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSDictionary), _NS_PRIVATE_SEL(dictionaryWithObjects_forKeys_count_), + pObjects, pKeys, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSDictionary)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init(const Object* const* pObjects, const Object* const* pKeys, UInteger count) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_forKeys_count_), pObjects, pKeys, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Dictionary::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::Enumerator<_KeyType>* NS::Dictionary::keyEnumerator() const +{ + return Object::sendMessage*>(this, 
_NS_PRIVATE_SEL(keyEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Object* NS::Dictionary::object(const Object* pKey) const +{ + return Object::sendMessage<_Object*>(this, _NS_PRIVATE_SEL(objectForKey_), pKey); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Dictionary::count() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSEnumerator.hpp b/dist/include/metal_cpp/Foundation/NSEnumerator.hpp new file mode 100644 index 0000000..5a2500c --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSEnumerator.hpp @@ -0,0 +1,78 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSEnumerator.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +struct FastEnumerationState +{ + unsigned long state; + Object** itemsPtr; + unsigned long* mutationsPtr; + unsigned long extra[5]; +} _NS_PACKED; + +class FastEnumeration : public Referencing +{ +public: + NS::UInteger countByEnumerating(FastEnumerationState* pState, Object** pBuffer, NS::UInteger len); +}; + +template +class Enumerator : public Referencing, FastEnumeration> +{ +public: + _ObjectType* nextObject(); + class Array* allObjects(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::FastEnumeration::countByEnumerating(FastEnumerationState* pState, Object** pBuffer, NS::UInteger len) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(countByEnumeratingWithState_objects_count_), pState, pBuffer, len); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _ObjectType* NS::Enumerator<_ObjectType>::nextObject() +{ + return Object::sendMessage<_ObjectType*>(this, _NS_PRIVATE_SEL(nextObject)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::Array* 
NS::Enumerator<_ObjectType>::allObjects() +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(allObjects)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSError.hpp b/dist/include/metal_cpp/Foundation/NSError.hpp new file mode 100644 index 0000000..ea331d4 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSError.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSError.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using ErrorDomain = class String*; + +_NS_CONST(ErrorDomain, CocoaErrorDomain); +_NS_CONST(ErrorDomain, POSIXErrorDomain); +_NS_CONST(ErrorDomain, OSStatusErrorDomain); +_NS_CONST(ErrorDomain, MachErrorDomain); + +using ErrorUserInfoKey = class String*; + +_NS_CONST(ErrorUserInfoKey, UnderlyingErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedDescriptionKey); +_NS_CONST(ErrorUserInfoKey, LocalizedFailureReasonErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedRecoverySuggestionErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedRecoveryOptionsErrorKey); +_NS_CONST(ErrorUserInfoKey, RecoveryAttempterErrorKey); +_NS_CONST(ErrorUserInfoKey, HelpAnchorErrorKey); +_NS_CONST(ErrorUserInfoKey, DebugDescriptionErrorKey); +_NS_CONST(ErrorUserInfoKey, LocalizedFailureErrorKey); +_NS_CONST(ErrorUserInfoKey, StringEncodingErrorKey); +_NS_CONST(ErrorUserInfoKey, URLErrorKey); +_NS_CONST(ErrorUserInfoKey, FilePathErrorKey); + +class Error : public Copying +{ +public: + static Error* error(ErrorDomain domain, Integer code, class Dictionary* pDictionary); + + static Error* alloc(); + Error* init(); + Error* init(ErrorDomain domain, Integer code, class Dictionary* pDictionary); + + Integer code() const; + ErrorDomain domain() const; + class Dictionary* userInfo() const; + + class String* localizedDescription() const; + class Array* localizedRecoveryOptions() const; + 
class String* localizedRecoverySuggestion() const; + class String* localizedFailureReason() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, CocoaErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, POSIXErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, OSStatusErrorDomain); +_NS_PRIVATE_DEF_CONST(NS::ErrorDomain, MachErrorDomain); + +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, UnderlyingErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedDescriptionKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedFailureReasonErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedRecoverySuggestionErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedRecoveryOptionsErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, RecoveryAttempterErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, HelpAnchorErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, DebugDescriptionErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, LocalizedFailureErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, StringEncodingErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, URLErrorKey); +_NS_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, FilePathErrorKey); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::error(ErrorDomain domain, Integer code, class Dictionary* pDictionary) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSError), _NS_PRIVATE_SEL(errorWithDomain_code_userInfo_), domain, code, pDictionary); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::alloc() +{ + return 
Object::alloc(_NS_PRIVATE_CLS(NSError)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Error* NS::Error::init(ErrorDomain domain, Integer code, class Dictionary* pDictionary) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithDomain_code_userInfo_), domain, code, pDictionary); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Integer NS::Error::code() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(code)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ErrorDomain NS::Error::domain() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(domain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Error::userInfo() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(userInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedDescription() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedDescription)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + 
+_NS_INLINE NS::Array* NS::Error::localizedRecoveryOptions() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedRecoveryOptions)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedRecoverySuggestion() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedRecoverySuggestion)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Error::localizedFailureReason() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(localizedFailureReason)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSLock.hpp b/dist/include/metal_cpp/Foundation/NSLock.hpp new file mode 100644 index 0000000..01df219 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSLock.hpp @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSLock.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" +#include "NSDate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +template +class Locking : public _Base +{ +public: + void lock(); + void unlock(); +}; + +class Condition : public Locking +{ +public: + static Condition* alloc(); + + Condition* init(); + + void wait(); + bool waitUntilDate(Date* pLimit); + void signal(); + void broadcast(); +}; + +} // NS + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Locking<_Class, _Base>::lock() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(lock)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Locking<_Class, _Base>::unlock() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(unlock)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Condition* NS::Condition::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSCondition)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE 
NS::Condition* NS::Condition::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::wait() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(wait)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Condition::waitUntilDate(NS::Date* pLimit) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(waitUntilDate_), pLimit); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::signal() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(signal)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Condition::broadcast() +{ + NS::Object::sendMessage(this, _NS_PRIVATE_SEL(broadcast)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- \ No newline at end of file diff --git a/dist/include/metal_cpp/Foundation/NSNotification.hpp b/dist/include/metal_cpp/Foundation/NSNotification.hpp new file mode 100644 index 0000000..6b5be12 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSNotification.hpp @@ -0,0 +1,110 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSNotification.hpp +// +// Copyright 2020-2024 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSDictionary.hpp" +#include "NSObject.hpp" +#include "NSString.hpp" +#include "NSTypes.hpp" +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using NotificationName = class String*; + +class Notification : public NS::Referencing +{ +public: + NS::String* name() const; + NS::Object* object() const; + NS::Dictionary* userInfo() const; +}; + +using ObserverBlock = void(^)(Notification*); +using ObserverFunction = std::function; + +class NotificationCenter : public NS::Referencing +{ + public: + static class NotificationCenter* defaultCenter(); + Object* addObserver(NotificationName name, Object* pObj, void* pQueue, ObserverBlock block); + Object* addObserver(NotificationName name, Object* pObj, void* pQueue, ObserverFunction &handler); + void removeObserver(Object* pObserver); + +}; +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Notification::name() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(name)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::Notification::object() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(object)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::Notification::userInfo() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(userInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::NotificationCenter* NS::NotificationCenter::defaultCenter() +{ + return NS::Object::sendMessage(_NS_PRIVATE_CLS(NSNotificationCenter), _NS_PRIVATE_SEL(defaultCenter)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::NotificationCenter::addObserver(NS::NotificationName name, Object* pObj, void* pQueue, NS::ObserverBlock block) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(addObserverName_object_queue_block_), name, pObj, pQueue, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::NotificationCenter::addObserver(NS::NotificationName name, Object* pObj, void* pQueue, NS::ObserverFunction &handler) +{ + __block ObserverFunction 
blockFunction = handler; + + return addObserver(name, pObj, pQueue, ^(NS::Notification* pNotif) {blockFunction(pNotif);}); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::NotificationCenter::removeObserver(Object* pObserver) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(removeObserver_), pObserver); +} + diff --git a/dist/include/metal_cpp/Foundation/NSNumber.hpp b/dist/include/metal_cpp/Foundation/NSNumber.hpp new file mode 100644 index 0000000..eec7cea --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSNumber.hpp @@ -0,0 +1,501 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSNumber.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObjCRuntime.hpp" +#include "NSObject.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class Value : public Copying +{ +public: + static Value* value(const void* pValue, const char* pType); + static Value* value(const void* pPointer); + + static Value* alloc(); + + Value* init(const void* pValue, const char* pType); + Value* init(const class Coder* pCoder); + + void getValue(void* pValue, UInteger size) const; + const char* objCType() const; + + bool isEqualToValue(Value* pValue) const; + void* pointerValue() const; +}; + +class Number : public Copying +{ +public: + static Number* number(char value); + static Number* number(unsigned char value); + static Number* number(short value); + static Number* number(unsigned short value); + static Number* number(int value); + static Number* number(unsigned int value); + static Number* number(long value); + static Number* number(unsigned long value); + static Number* number(long long value); + static Number* number(unsigned long long value); + static Number* number(float value); + static Number* number(double value); + static Number* number(bool value); + + static Number* alloc(); + + Number* init(const class Coder* pCoder); + Number* init(char value); + Number* init(unsigned char value); + Number* init(short value); + Number* init(unsigned short value); + Number* init(int value); + Number* init(unsigned int value); + Number* init(long value); + Number* init(unsigned long value); + Number* init(long long value); + 
Number* init(unsigned long long value); + Number* init(float value); + Number* init(double value); + Number* init(bool value); + + char charValue() const; + unsigned char unsignedCharValue() const; + short shortValue() const; + unsigned short unsignedShortValue() const; + int intValue() const; + unsigned int unsignedIntValue() const; + long longValue() const; + unsigned long unsignedLongValue() const; + long long longLongValue() const; + unsigned long long unsignedLongLongValue() const; + float floatValue() const; + double doubleValue() const; + bool boolValue() const; + Integer integerValue() const; + UInteger unsignedIntegerValue() const; + class String* stringValue() const; + + ComparisonResult compare(const Number* pOtherNumber) const; + bool isEqualToNumber(const Number* pNumber) const; + + class String* descriptionWithLocale(const Object* pLocale) const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::value(const void* pValue, const char* pType) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSValue), _NS_PRIVATE_SEL(valueWithBytes_objCType_), pValue, pType); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::value(const void* pPointer) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSValue), _NS_PRIVATE_SEL(valueWithPointer_), pPointer); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSValue)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::init(const void* pValue, const char* pType) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBytes_objCType_), pValue, pType); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Value* NS::Value::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::Value::getValue(void* pValue, UInteger size) const +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(getValue_size_), pValue, size); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::Value::objCType() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(objCType)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Value::isEqualToValue(Value* pValue) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToValue_), pValue); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void* NS::Value::pointerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(pointerValue)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(char value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned char value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(short value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned short value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(int value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned int value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedInt_), value); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(long long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(unsigned long long value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithUnsignedLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(float value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithFloat_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(double value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithDouble_), value); +} 
+ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::number(bool value) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSNumber), _NS_PRIVATE_SEL(numberWithBool_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSNumber)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(const Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(char value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned char value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedChar_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(short value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE 
NS::Number* NS::Number::init(unsigned short value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedShort_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(int value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned int value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedInt_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(long long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithLongLong_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(unsigned long long value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithUnsignedLongLong_), value); +} 
+ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(float value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithFloat_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(double value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithDouble_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Number* NS::Number::init(bool value) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBool_), value); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE char NS::Number::charValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(charValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned char NS::Number::unsignedCharValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedCharValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE short NS::Number::shortValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(shortValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned short 
NS::Number::unsignedShortValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedShortValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE int NS::Number::intValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(intValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned int NS::Number::unsignedIntValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedIntValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE long NS::Number::longValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(longValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long NS::Number::unsignedLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedLongValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE long long NS::Number::longLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(longLongValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long long NS::Number::unsignedLongLongValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedLongLongValue)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE float NS::Number::floatValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(floatValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE double NS::Number::doubleValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(doubleValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Number::boolValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(boolValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Integer NS::Number::integerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(integerValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Number::unsignedIntegerValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(unsignedIntegerValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Number::stringValue() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(stringValue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ComparisonResult NS::Number::compare(const Number* pOtherNumber) 
const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(compare_), pOtherNumber); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Number::isEqualToNumber(const Number* pNumber) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToNumber_), pNumber); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::Number::descriptionWithLocale(const Object* pLocale) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(descriptionWithLocale_), pLocale); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSObjCRuntime.hpp b/dist/include/metal_cpp/Foundation/NSObjCRuntime.hpp new file mode 100644 index 0000000..9a5364c --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSObjCRuntime.hpp @@ -0,0 +1,43 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSObjCRuntime.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ + +_NS_ENUM(Integer, ComparisonResult) { + OrderedAscending = -1L, + OrderedSame, + OrderedDescending +}; + +const Integer NotFound = IntegerMax; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSObject.hpp b/dist/include/metal_cpp/Foundation/NSObject.hpp new file mode 100644 index 0000000..aff8e67 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSObject.hpp @@ -0,0 +1,302 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSObject.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +#include +#include + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +template +class _NS_EXPORT Referencing : public _Base +{ +public: + _Class* retain(); + void release(); + + _Class* autorelease(); + + UInteger retainCount() const; +}; + +template +class Copying : public Referencing<_Class, _Base> +{ +public: + _Class* copy() const; +}; + +template +class SecureCoding : public Referencing<_Class, _Base> +{ +}; + +class Object : public Referencing +{ +public: + UInteger hash() const; + bool isEqual(const Object* pObject) const; + + class String* description() const; + class String* debugDescription() const; + +protected: + friend class Referencing; + + template + static _Class* alloc(const char* pClassName); + template + static _Class* alloc(const void* pClass); + template + _Class* init(); + + template + static _Dst bridgingCast(const void* pObj); + static class MethodSignature* methodSignatureForSelector(const void* pObj, SEL selector); + static bool respondsToSelector(const void* pObj, SEL selector); + template + static constexpr bool doesRequireMsgSendStret(); + template + static _Ret sendMessage(const void* pObj, SEL selector, _Args... args); + template + static _Ret sendMessageSafe(const void* pObj, SEL selector, _Args... 
args); + +private: + Object() = delete; + Object(const Object&) = delete; + ~Object() = delete; + + Object& operator=(const Object&) = delete; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Referencing<_Class, _Base>::retain() +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(retain)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE void NS::Referencing<_Class, _Base>::release() +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(release)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Referencing<_Class, _Base>::autorelease() +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(autorelease)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE NS::UInteger NS::Referencing<_Class, _Base>::retainCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(retainCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Class* NS::Copying<_Class, _Base>::copy() const +{ + return Object::sendMessage<_Class*>(this, _NS_PRIVATE_SEL(copy)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +template +_NS_INLINE _Dst NS::Object::bridgingCast(const void* pObj) +{ +#ifdef __OBJC__ + return (__bridge _Dst)pObj; 
+#else
+    return (_Dst)pObj;
+#endif // __OBJC__
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <typename _Type>
+_NS_INLINE constexpr bool NS::Object::doesRequireMsgSendStret()
+{
+#if (defined(__i386__) || defined(__x86_64__))
+    constexpr size_t kStructLimit = (sizeof(std::uintptr_t) << 1);
+
+    return sizeof(_Type) > kStructLimit;
+#elif defined(__arm64__)
+    return false;
+#elif defined(__arm__)
+    constexpr size_t kStructLimit = sizeof(std::uintptr_t);
+
+    return std::is_class_v<_Type> && (sizeof(_Type) > kStructLimit);
+#else
+#error "Unsupported architecture!"
+#endif
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <>
+_NS_INLINE constexpr bool NS::Object::doesRequireMsgSendStret<void>()
+{
+    return false;
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <typename _Ret, typename... _Args>
+_NS_INLINE _Ret NS::Object::sendMessage(const void* pObj, SEL selector, _Args...
args)
+{
+#if (defined(__i386__) || defined(__x86_64__))
+    if constexpr (std::is_floating_point<_Ret>())
+    {
+        using SendMessageProcFpret = _Ret (*)(const void*, SEL, _Args...);
+
+        const SendMessageProcFpret pProc = reinterpret_cast<SendMessageProcFpret>(&objc_msgSend_fpret);
+
+        return (*pProc)(pObj, selector, args...);
+    }
+    else
+#endif // ( defined( __i386__ ) || defined( __x86_64__ ) )
+#if !defined(__arm64__)
+    if constexpr (doesRequireMsgSendStret<_Ret>())
+    {
+        using SendMessageProcStret = void (*)(_Ret*, const void*, SEL, _Args...);
+
+        const SendMessageProcStret pProc = reinterpret_cast<SendMessageProcStret>(&objc_msgSend_stret);
+        _Ret ret;
+
+        (*pProc)(&ret, pObj, selector, args...);
+
+        return ret;
+    }
+    else
+#endif // !defined( __arm64__ )
+    {
+        using SendMessageProc = _Ret (*)(const void*, SEL, _Args...);
+
+        const SendMessageProc pProc = reinterpret_cast<SendMessageProc>(&objc_msgSend);
+
+        return (*pProc)(pObj, selector, args...);
+    }
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE NS::MethodSignature* NS::Object::methodSignatureForSelector(const void* pObj, SEL selector)
+{
+    return sendMessage<MethodSignature*>(pObj, _NS_PRIVATE_SEL(methodSignatureForSelector_), selector);
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE bool NS::Object::respondsToSelector(const void* pObj, SEL selector)
+{
+    return sendMessage<bool>(pObj, _NS_PRIVATE_SEL(respondsToSelector_), selector);
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <typename _Ret, typename... _Args>
+_NS_INLINE _Ret NS::Object::sendMessageSafe(const void* pObj, SEL selector, _Args...
args)
+{
+    if ((respondsToSelector(pObj, selector)) || (nullptr != methodSignatureForSelector(pObj, selector)))
+    {
+        return sendMessage<_Ret>(pObj, selector, args...);
+    }
+
+    if constexpr (!std::is_void<_Ret>::value)
+    {
+        return _Ret(0);
+    }
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <class _Class>
+_NS_INLINE _Class* NS::Object::alloc(const char* pClassName)
+{
+    return sendMessage<_Class*>(objc_lookUpClass(pClassName), _NS_PRIVATE_SEL(alloc));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <class _Class>
+_NS_INLINE _Class* NS::Object::alloc(const void* pClass)
+{
+    return sendMessage<_Class*>(pClass, _NS_PRIVATE_SEL(alloc));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+template <class _Class>
+_NS_INLINE _Class* NS::Object::init()
+{
+    return sendMessage<_Class*>(this, _NS_PRIVATE_SEL(init));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE NS::UInteger NS::Object::hash() const
+{
+    return sendMessage<UInteger>(this, _NS_PRIVATE_SEL(hash));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE bool NS::Object::isEqual(const Object* pObject) const
+{
+    return sendMessage<bool>(this, _NS_PRIVATE_SEL(isEqual_), pObject);
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE NS::String* NS::Object::description() const
+{
+    return sendMessage<String*>(this,
_NS_PRIVATE_SEL(description));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+_NS_INLINE NS::String* NS::Object::debugDescription() const
+{
+    return sendMessageSafe<String*>(this, _NS_PRIVATE_SEL(debugDescription));
+}
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
diff --git a/dist/include/metal_cpp/Foundation/NSPrivate.hpp b/dist/include/metal_cpp/Foundation/NSPrivate.hpp
new file mode 100644
index 0000000..f8d8700
--- /dev/null
+++ b/dist/include/metal_cpp/Foundation/NSPrivate.hpp
@@ -0,0 +1,531 @@
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+//
+// Foundation/NSPrivate.hpp
+//
+// Copyright 2020-2024 Apple Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#pragma once
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#include <objc/runtime.h>
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#define _NS_PRIVATE_CLS(symbol) (Private::Class::s_k##symbol)
+#define _NS_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor)
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#if defined(NS_PRIVATE_IMPLEMENTATION)
+
+#include <dlfcn.h>
+
+namespace NS::Private
+{
+    template <typename _Type>
+    inline _Type const LoadSymbol(const char* pSymbol)
+    {
+        const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol));
+
+        return pAddress ?
*pAddress : _Type(); + } +} // NS::Private + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _NS_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _NS_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _NS_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _NS_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _NS_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _NS_PRIVATE_VISIBILITY = _NS_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _NS_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _NS_PRIVATE_VISIBILITY = sel_registerName(symbol) + +#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0) +#define _NS_PRIVATE_DEF_CONST(type, symbol) \ + _NS_EXTERN type const NS##symbol _NS_PRIVATE_IMPORT; \ + type const NS::symbol = (nullptr != &NS##symbol) ? 
NS##symbol : type()
+#else
+#define _NS_PRIVATE_DEF_CONST(type, symbol) \
+    _NS_EXTERN type const MTL##symbol _NS_PRIVATE_IMPORT; \
+    type const NS::symbol = Private::LoadSymbol<type>("NS" #symbol)
+#endif
+
+#else
+
+#define _NS_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol
+#define _NS_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol
+#define _NS_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor
+#define _NS_PRIVATE_DEF_CONST(type, symbol) extern type const NS::symbol
+
+#endif // NS_PRIVATE_IMPLEMENTATION
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+namespace NS
+{
+namespace Private
+{
+    namespace Class
+    {
+
+        _NS_PRIVATE_DEF_CLS(NSArray);
+        _NS_PRIVATE_DEF_CLS(NSAutoreleasePool);
+        _NS_PRIVATE_DEF_CLS(NSBundle);
+        _NS_PRIVATE_DEF_CLS(NSCondition);
+        _NS_PRIVATE_DEF_CLS(NSDate);
+        _NS_PRIVATE_DEF_CLS(NSDictionary);
+        _NS_PRIVATE_DEF_CLS(NSError);
+        _NS_PRIVATE_DEF_CLS(NSNotificationCenter);
+        _NS_PRIVATE_DEF_CLS(NSNumber);
+        _NS_PRIVATE_DEF_CLS(NSObject);
+        _NS_PRIVATE_DEF_CLS(NSProcessInfo);
+        _NS_PRIVATE_DEF_CLS(NSSet);
+        _NS_PRIVATE_DEF_CLS(NSString);
+        _NS_PRIVATE_DEF_CLS(NSURL);
+        _NS_PRIVATE_DEF_CLS(NSValue);
+
+    } // Class
+} // Private
+} // MTL
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+namespace NS
+{
+namespace Private
+{
+    namespace Protocol
+    {
+
+    } // Protocol
+} // Private
+} // NS
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+namespace NS
+{
+namespace Private
+{
+    namespace Selector
+    {
+
+        _NS_PRIVATE_DEF_SEL(addObject_,
+            "addObject:");
+        _NS_PRIVATE_DEF_SEL(addObserverName_object_queue_block_,
+            "addObserverForName:object:queue:usingBlock:");
+
_NS_PRIVATE_DEF_SEL(activeProcessorCount, + "activeProcessorCount"); + _NS_PRIVATE_DEF_SEL(allBundles, + "allBundles"); + _NS_PRIVATE_DEF_SEL(allFrameworks, + "allFrameworks"); + _NS_PRIVATE_DEF_SEL(allObjects, + "allObjects"); + _NS_PRIVATE_DEF_SEL(alloc, + "alloc"); + _NS_PRIVATE_DEF_SEL(appStoreReceiptURL, + "appStoreReceiptURL"); + _NS_PRIVATE_DEF_SEL(arguments, + "arguments"); + _NS_PRIVATE_DEF_SEL(array, + "array"); + _NS_PRIVATE_DEF_SEL(arrayWithObject_, + "arrayWithObject:"); + _NS_PRIVATE_DEF_SEL(arrayWithObjects_count_, + "arrayWithObjects:count:"); + _NS_PRIVATE_DEF_SEL(automaticTerminationSupportEnabled, + "automaticTerminationSupportEnabled"); + _NS_PRIVATE_DEF_SEL(autorelease, + "autorelease"); + _NS_PRIVATE_DEF_SEL(beginActivityWithOptions_reason_, + "beginActivityWithOptions:reason:"); + _NS_PRIVATE_DEF_SEL(boolValue, + "boolValue"); + _NS_PRIVATE_DEF_SEL(broadcast, + "broadcast"); + _NS_PRIVATE_DEF_SEL(builtInPlugInsPath, + "builtInPlugInsPath"); + _NS_PRIVATE_DEF_SEL(builtInPlugInsURL, + "builtInPlugInsURL"); + _NS_PRIVATE_DEF_SEL(bundleIdentifier, + "bundleIdentifier"); + _NS_PRIVATE_DEF_SEL(bundlePath, + "bundlePath"); + _NS_PRIVATE_DEF_SEL(bundleURL, + "bundleURL"); + _NS_PRIVATE_DEF_SEL(bundleWithPath_, + "bundleWithPath:"); + _NS_PRIVATE_DEF_SEL(bundleWithURL_, + "bundleWithURL:"); + _NS_PRIVATE_DEF_SEL(caseInsensitiveCompare_, + "caseInsensitiveCompare:"); + _NS_PRIVATE_DEF_SEL(characterAtIndex_, + "characterAtIndex:"); + _NS_PRIVATE_DEF_SEL(charValue, + "charValue"); + _NS_PRIVATE_DEF_SEL(countByEnumeratingWithState_objects_count_, + "countByEnumeratingWithState:objects:count:"); + _NS_PRIVATE_DEF_SEL(cStringUsingEncoding_, + "cStringUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(code, + "code"); + _NS_PRIVATE_DEF_SEL(compare_, + "compare:"); + _NS_PRIVATE_DEF_SEL(copy, + "copy"); + _NS_PRIVATE_DEF_SEL(count, + "count"); + _NS_PRIVATE_DEF_SEL(dateWithTimeIntervalSinceNow_, + "dateWithTimeIntervalSinceNow:"); + _NS_PRIVATE_DEF_SEL(defaultCenter, + 
"defaultCenter"); + _NS_PRIVATE_DEF_SEL(descriptionWithLocale_, + "descriptionWithLocale:"); + _NS_PRIVATE_DEF_SEL(disableAutomaticTermination_, + "disableAutomaticTermination:"); + _NS_PRIVATE_DEF_SEL(disableSuddenTermination, + "disableSuddenTermination"); + _NS_PRIVATE_DEF_SEL(debugDescription, + "debugDescription"); + _NS_PRIVATE_DEF_SEL(description, + "description"); + _NS_PRIVATE_DEF_SEL(dictionary, + "dictionary"); + _NS_PRIVATE_DEF_SEL(dictionaryWithObject_forKey_, + "dictionaryWithObject:forKey:"); + _NS_PRIVATE_DEF_SEL(dictionaryWithObjects_forKeys_count_, + "dictionaryWithObjects:forKeys:count:"); + _NS_PRIVATE_DEF_SEL(domain, + "domain"); + _NS_PRIVATE_DEF_SEL(doubleValue, + "doubleValue"); + _NS_PRIVATE_DEF_SEL(drain, + "drain"); + _NS_PRIVATE_DEF_SEL(enableAutomaticTermination_, + "enableAutomaticTermination:"); + _NS_PRIVATE_DEF_SEL(enableSuddenTermination, + "enableSuddenTermination"); + _NS_PRIVATE_DEF_SEL(endActivity_, + "endActivity:"); + _NS_PRIVATE_DEF_SEL(environment, + "environment"); + _NS_PRIVATE_DEF_SEL(errorWithDomain_code_userInfo_, + "errorWithDomain:code:userInfo:"); + _NS_PRIVATE_DEF_SEL(executablePath, + "executablePath"); + _NS_PRIVATE_DEF_SEL(executableURL, + "executableURL"); + _NS_PRIVATE_DEF_SEL(fileSystemRepresentation, + "fileSystemRepresentation"); + _NS_PRIVATE_DEF_SEL(fileURLWithPath_, + "fileURLWithPath:"); + _NS_PRIVATE_DEF_SEL(floatValue, + "floatValue"); + _NS_PRIVATE_DEF_SEL(fullUserName, + "fullUserName"); + _NS_PRIVATE_DEF_SEL(getValue_size_, + "getValue:size:"); + _NS_PRIVATE_DEF_SEL(globallyUniqueString, + "globallyUniqueString"); + _NS_PRIVATE_DEF_SEL(hash, + "hash"); + _NS_PRIVATE_DEF_SEL(hasPerformanceProfile_, + "hasPerformanceProfile:"); + _NS_PRIVATE_DEF_SEL(hostName, + "hostName"); + _NS_PRIVATE_DEF_SEL(infoDictionary, + "infoDictionary"); + _NS_PRIVATE_DEF_SEL(init, + "init"); + _NS_PRIVATE_DEF_SEL(initFileURLWithPath_, + "initFileURLWithPath:"); + _NS_PRIVATE_DEF_SEL(initWithBool_, + "initWithBool:"); + 
_NS_PRIVATE_DEF_SEL(initWithBytes_objCType_, + "initWithBytes:objCType:"); + _NS_PRIVATE_DEF_SEL(initWithBytesNoCopy_length_encoding_freeWhenDone_, + "initWithBytesNoCopy:length:encoding:freeWhenDone:"); + _NS_PRIVATE_DEF_SEL(initWithChar_, + "initWithChar:"); + _NS_PRIVATE_DEF_SEL(initWithCoder_, + "initWithCoder:"); + _NS_PRIVATE_DEF_SEL(initWithCString_encoding_, + "initWithCString:encoding:"); + _NS_PRIVATE_DEF_SEL(initWithDomain_code_userInfo_, + "initWithDomain:code:userInfo:"); + _NS_PRIVATE_DEF_SEL(initWithDouble_, + "initWithDouble:"); + _NS_PRIVATE_DEF_SEL(initWithFloat_, + "initWithFloat:"); + _NS_PRIVATE_DEF_SEL(initWithInt_, + "initWithInt:"); + _NS_PRIVATE_DEF_SEL(initWithLong_, + "initWithLong:"); + _NS_PRIVATE_DEF_SEL(initWithLongLong_, + "initWithLongLong:"); + _NS_PRIVATE_DEF_SEL(initWithObjects_count_, + "initWithObjects:count:"); + _NS_PRIVATE_DEF_SEL(initWithObjects_forKeys_count_, + "initWithObjects:forKeys:count:"); + _NS_PRIVATE_DEF_SEL(initWithPath_, + "initWithPath:"); + _NS_PRIVATE_DEF_SEL(initWithShort_, + "initWithShort:"); + _NS_PRIVATE_DEF_SEL(initWithString_, + "initWithString:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedChar_, + "initWithUnsignedChar:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedInt_, + "initWithUnsignedInt:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedLong_, + "initWithUnsignedLong:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedLongLong_, + "initWithUnsignedLongLong:"); + _NS_PRIVATE_DEF_SEL(initWithUnsignedShort_, + "initWithUnsignedShort:"); + _NS_PRIVATE_DEF_SEL(initWithURL_, + "initWithURL:"); + _NS_PRIVATE_DEF_SEL(integerValue, + "integerValue"); + _NS_PRIVATE_DEF_SEL(intValue, + "intValue"); + _NS_PRIVATE_DEF_SEL(isDeviceCertified_, + "isDeviceCertifiedFor:"); + _NS_PRIVATE_DEF_SEL(isEqual_, + "isEqual:"); + _NS_PRIVATE_DEF_SEL(isEqualToNumber_, + "isEqualToNumber:"); + _NS_PRIVATE_DEF_SEL(isEqualToString_, + "isEqualToString:"); + _NS_PRIVATE_DEF_SEL(isEqualToValue_, + "isEqualToValue:"); + 
_NS_PRIVATE_DEF_SEL(isiOSAppOnMac, + "isiOSAppOnMac"); + _NS_PRIVATE_DEF_SEL(isLoaded, + "isLoaded"); + _NS_PRIVATE_DEF_SEL(isLowPowerModeEnabled, + "isLowPowerModeEnabled"); + _NS_PRIVATE_DEF_SEL(isMacCatalystApp, + "isMacCatalystApp"); + _NS_PRIVATE_DEF_SEL(isOperatingSystemAtLeastVersion_, + "isOperatingSystemAtLeastVersion:"); + _NS_PRIVATE_DEF_SEL(keyEnumerator, + "keyEnumerator"); + _NS_PRIVATE_DEF_SEL(length, + "length"); + _NS_PRIVATE_DEF_SEL(lengthOfBytesUsingEncoding_, + "lengthOfBytesUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(load, + "load"); + _NS_PRIVATE_DEF_SEL(loadAndReturnError_, + "loadAndReturnError:"); + _NS_PRIVATE_DEF_SEL(localizedDescription, + "localizedDescription"); + _NS_PRIVATE_DEF_SEL(localizedFailureReason, + "localizedFailureReason"); + _NS_PRIVATE_DEF_SEL(localizedInfoDictionary, + "localizedInfoDictionary"); + _NS_PRIVATE_DEF_SEL(localizedRecoveryOptions, + "localizedRecoveryOptions"); + _NS_PRIVATE_DEF_SEL(localizedRecoverySuggestion, + "localizedRecoverySuggestion"); + _NS_PRIVATE_DEF_SEL(localizedStringForKey_value_table_, + "localizedStringForKey:value:table:"); + _NS_PRIVATE_DEF_SEL(lock, + "lock"); + _NS_PRIVATE_DEF_SEL(longValue, + "longValue"); + _NS_PRIVATE_DEF_SEL(longLongValue, + "longLongValue"); + _NS_PRIVATE_DEF_SEL(mainBundle, + "mainBundle"); + _NS_PRIVATE_DEF_SEL(maximumLengthOfBytesUsingEncoding_, + "maximumLengthOfBytesUsingEncoding:"); + _NS_PRIVATE_DEF_SEL(methodSignatureForSelector_, + "methodSignatureForSelector:"); + _NS_PRIVATE_DEF_SEL(mutableBytes, + "mutableBytes"); + _NS_PRIVATE_DEF_SEL(name, + "name"); + _NS_PRIVATE_DEF_SEL(nextObject, + "nextObject"); + _NS_PRIVATE_DEF_SEL(numberWithBool_, + "numberWithBool:"); + _NS_PRIVATE_DEF_SEL(numberWithChar_, + "numberWithChar:"); + _NS_PRIVATE_DEF_SEL(numberWithDouble_, + "numberWithDouble:"); + _NS_PRIVATE_DEF_SEL(numberWithFloat_, + "numberWithFloat:"); + _NS_PRIVATE_DEF_SEL(numberWithInt_, + "numberWithInt:"); + _NS_PRIVATE_DEF_SEL(numberWithLong_, + 
"numberWithLong:"); + _NS_PRIVATE_DEF_SEL(numberWithLongLong_, + "numberWithLongLong:"); + _NS_PRIVATE_DEF_SEL(numberWithShort_, + "numberWithShort:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedChar_, + "numberWithUnsignedChar:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedInt_, + "numberWithUnsignedInt:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedLong_, + "numberWithUnsignedLong:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedLongLong_, + "numberWithUnsignedLongLong:"); + _NS_PRIVATE_DEF_SEL(numberWithUnsignedShort_, + "numberWithUnsignedShort:"); + _NS_PRIVATE_DEF_SEL(objCType, + "objCType"); + _NS_PRIVATE_DEF_SEL(object, + "object"); + _NS_PRIVATE_DEF_SEL(objectAtIndex_, + "objectAtIndex:"); + _NS_PRIVATE_DEF_SEL(objectEnumerator, + "objectEnumerator"); + _NS_PRIVATE_DEF_SEL(objectForInfoDictionaryKey_, + "objectForInfoDictionaryKey:"); + _NS_PRIVATE_DEF_SEL(objectForKey_, + "objectForKey:"); + _NS_PRIVATE_DEF_SEL(operatingSystem, + "operatingSystem"); + _NS_PRIVATE_DEF_SEL(operatingSystemVersion, + "operatingSystemVersion"); + _NS_PRIVATE_DEF_SEL(operatingSystemVersionString, + "operatingSystemVersionString"); + _NS_PRIVATE_DEF_SEL(pathForAuxiliaryExecutable_, + "pathForAuxiliaryExecutable:"); + _NS_PRIVATE_DEF_SEL(performActivityWithOptions_reason_usingBlock_, + "performActivityWithOptions:reason:usingBlock:"); + _NS_PRIVATE_DEF_SEL(performExpiringActivityWithReason_usingBlock_, + "performExpiringActivityWithReason:usingBlock:"); + _NS_PRIVATE_DEF_SEL(physicalMemory, + "physicalMemory"); + _NS_PRIVATE_DEF_SEL(pointerValue, + "pointerValue"); + _NS_PRIVATE_DEF_SEL(preflightAndReturnError_, + "preflightAndReturnError:"); + _NS_PRIVATE_DEF_SEL(privateFrameworksPath, + "privateFrameworksPath"); + _NS_PRIVATE_DEF_SEL(privateFrameworksURL, + "privateFrameworksURL"); + _NS_PRIVATE_DEF_SEL(processIdentifier, + "processIdentifier"); + _NS_PRIVATE_DEF_SEL(processInfo, + "processInfo"); + _NS_PRIVATE_DEF_SEL(processName, + "processName"); + _NS_PRIVATE_DEF_SEL(processorCount, + 
"processorCount"); + _NS_PRIVATE_DEF_SEL(rangeOfString_options_, + "rangeOfString:options:"); + _NS_PRIVATE_DEF_SEL(release, + "release"); + _NS_PRIVATE_DEF_SEL(removeObserver_, + "removeObserver:"); + _NS_PRIVATE_DEF_SEL(resourcePath, + "resourcePath"); + _NS_PRIVATE_DEF_SEL(resourceURL, + "resourceURL"); + _NS_PRIVATE_DEF_SEL(respondsToSelector_, + "respondsToSelector:"); + _NS_PRIVATE_DEF_SEL(retain, + "retain"); + _NS_PRIVATE_DEF_SEL(retainCount, + "retainCount"); + _NS_PRIVATE_DEF_SEL(setAutomaticTerminationSupportEnabled_, + "setAutomaticTerminationSupportEnabled:"); + _NS_PRIVATE_DEF_SEL(setProcessName_, + "setProcessName:"); + _NS_PRIVATE_DEF_SEL(sharedFrameworksPath, + "sharedFrameworksPath"); + _NS_PRIVATE_DEF_SEL(sharedFrameworksURL, + "sharedFrameworksURL"); + _NS_PRIVATE_DEF_SEL(sharedSupportPath, + "sharedSupportPath"); + _NS_PRIVATE_DEF_SEL(sharedSupportURL, + "sharedSupportURL"); + _NS_PRIVATE_DEF_SEL(shortValue, + "shortValue"); + _NS_PRIVATE_DEF_SEL(showPools, + "showPools"); + _NS_PRIVATE_DEF_SEL(signal, + "signal"); + _NS_PRIVATE_DEF_SEL(string, + "string"); + _NS_PRIVATE_DEF_SEL(stringValue, + "stringValue"); + _NS_PRIVATE_DEF_SEL(stringWithString_, + "stringWithString:"); + _NS_PRIVATE_DEF_SEL(stringWithCString_encoding_, + "stringWithCString:encoding:"); + _NS_PRIVATE_DEF_SEL(stringByAppendingString_, + "stringByAppendingString:"); + _NS_PRIVATE_DEF_SEL(systemUptime, + "systemUptime"); + _NS_PRIVATE_DEF_SEL(thermalState, + "thermalState"); + _NS_PRIVATE_DEF_SEL(unload, + "unload"); + _NS_PRIVATE_DEF_SEL(unlock, + "unlock"); + _NS_PRIVATE_DEF_SEL(unsignedCharValue, + "unsignedCharValue"); + _NS_PRIVATE_DEF_SEL(unsignedIntegerValue, + "unsignedIntegerValue"); + _NS_PRIVATE_DEF_SEL(unsignedIntValue, + "unsignedIntValue"); + _NS_PRIVATE_DEF_SEL(unsignedLongValue, + "unsignedLongValue"); + _NS_PRIVATE_DEF_SEL(unsignedLongLongValue, + "unsignedLongLongValue"); + _NS_PRIVATE_DEF_SEL(unsignedShortValue, + "unsignedShortValue"); + 
_NS_PRIVATE_DEF_SEL(URLForAuxiliaryExecutable_, + "URLForAuxiliaryExecutable:"); + _NS_PRIVATE_DEF_SEL(userInfo, + "userInfo"); + _NS_PRIVATE_DEF_SEL(userName, + "userName"); + _NS_PRIVATE_DEF_SEL(UTF8String, + "UTF8String"); + _NS_PRIVATE_DEF_SEL(valueWithBytes_objCType_, + "valueWithBytes:objCType:"); + _NS_PRIVATE_DEF_SEL(valueWithPointer_, + "valueWithPointer:"); + _NS_PRIVATE_DEF_SEL(wait, + "wait"); + _NS_PRIVATE_DEF_SEL(waitUntilDate_, + "waitUntilDate:"); + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSProcessInfo.hpp b/dist/include/metal_cpp/Foundation/NSProcessInfo.hpp new file mode 100644 index 0000000..09c212d --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSProcessInfo.hpp @@ -0,0 +1,386 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSProcessInfo.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSNotification.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_CONST(NotificationName, ProcessInfoThermalStateDidChangeNotification); +_NS_CONST(NotificationName, ProcessInfoPowerStateDidChangeNotification); +_NS_CONST(NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); + +_NS_ENUM(NS::Integer, ProcessInfoThermalState) { + ProcessInfoThermalStateNominal = 0, + ProcessInfoThermalStateFair = 1, + ProcessInfoThermalStateSerious = 2, + ProcessInfoThermalStateCritical = 3 +}; + +_NS_OPTIONS(std::uint64_t, ActivityOptions) { + ActivityIdleDisplaySleepDisabled = (1ULL << 40), + ActivityIdleSystemSleepDisabled = (1ULL << 20), + ActivitySuddenTerminationDisabled = (1ULL << 14), + ActivityAutomaticTerminationDisabled = (1ULL << 15), + ActivityUserInitiated = (0x00FFFFFFULL | ActivityIdleSystemSleepDisabled), + ActivityUserInitiatedAllowingIdleSystemSleep = (ActivityUserInitiated & ~ActivityIdleSystemSleepDisabled), + ActivityBackground = 0x000000FFULL, + ActivityLatencyCritical = 0xFF00000000ULL, +}; + +typedef NS::Integer DeviceCertification; +_NS_CONST(DeviceCertification, DeviceCertificationiPhonePerformanceGaming); + +typedef NS::Integer ProcessPerformanceProfile; +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_CONST(ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + +class ProcessInfo : public Referencing 
+{ +public: + static ProcessInfo* processInfo(); + + class Array* arguments() const; + class Dictionary* environment() const; + class String* hostName() const; + class String* processName() const; + void setProcessName(const String* pString); + int processIdentifier() const; + class String* globallyUniqueString() const; + + class String* userName() const; + class String* fullUserName() const; + + UInteger operatingSystem() const; + OperatingSystemVersion operatingSystemVersion() const; + class String* operatingSystemVersionString() const; + bool isOperatingSystemAtLeastVersion(OperatingSystemVersion version) const; + + UInteger processorCount() const; + UInteger activeProcessorCount() const; + unsigned long long physicalMemory() const; + TimeInterval systemUptime() const; + + void disableSuddenTermination(); + void enableSuddenTermination(); + + void disableAutomaticTermination(const class String* pReason); + void enableAutomaticTermination(const class String* pReason); + bool automaticTerminationSupportEnabled() const; + void setAutomaticTerminationSupportEnabled(bool enabled); + + class Object* beginActivity(ActivityOptions options, const class String* pReason); + void endActivity(class Object* pActivity); + void performActivity(ActivityOptions options, const class String* pReason, void (^block)(void)); + void performActivity(ActivityOptions options, const class String* pReason, const std::function& func); + void performExpiringActivity(const class String* pReason, void (^block)(bool expired)); + void performExpiringActivity(const class String* pReason, const std::function& func); + + ProcessInfoThermalState thermalState() const; + bool isLowPowerModeEnabled() const; + + bool isiOSAppOnMac() const; + bool isMacCatalystApp() const; + + bool isDeviceCertified(DeviceCertification performanceTier) const; + bool hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const; + +}; +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoThermalStateDidChangeNotification); +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPowerStateDidChangeNotification); + +// The linker searches for these symbols in the Metal framework, be sure to link it in as well: +_NS_PRIVATE_DEF_CONST(NS::NotificationName, ProcessInfoPerformanceProfileDidChangeNotification); +_NS_PRIVATE_DEF_CONST(NS::DeviceCertification, DeviceCertificationiPhonePerformanceGaming); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileDefault); +_NS_PRIVATE_DEF_CONST(NS::ProcessPerformanceProfile, ProcessPerformanceProfileSustained); + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ProcessInfo* NS::ProcessInfo::processInfo() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSProcessInfo), _NS_PRIVATE_SEL(processInfo)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Array* NS::ProcessInfo::arguments() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(arguments)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Dictionary* NS::ProcessInfo::environment() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(environment)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::hostName() const +{ + return Object::sendMessage(this, 
_NS_PRIVATE_SEL(hostName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::processName() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::setProcessName(const String* pString) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(setProcessName_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE int NS::ProcessInfo::processIdentifier() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processIdentifier)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::globallyUniqueString() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(globallyUniqueString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::userName() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(userName)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::fullUserName() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(fullUserName)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::operatingSystem() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystem)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::OperatingSystemVersion NS::ProcessInfo::operatingSystemVersion() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystemVersion)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::ProcessInfo::operatingSystemVersionString() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(operatingSystemVersionString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isOperatingSystemAtLeastVersion(OperatingSystemVersion version) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isOperatingSystemAtLeastVersion_), version); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::processorCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(processorCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::ProcessInfo::activeProcessorCount() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(activeProcessorCount)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE unsigned long long NS::ProcessInfo::physicalMemory() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(physicalMemory)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::TimeInterval NS::ProcessInfo::systemUptime() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(systemUptime)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::disableSuddenTermination() +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(disableSuddenTermination)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::enableSuddenTermination() +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(enableSuddenTermination)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::disableAutomaticTermination(const String* pReason) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(disableAutomaticTermination_), pReason); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::enableAutomaticTermination(const String* pReason) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(enableAutomaticTermination_), pReason); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::automaticTerminationSupportEnabled() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(automaticTerminationSupportEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::setAutomaticTerminationSupportEnabled(bool enabled) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(setAutomaticTerminationSupportEnabled_), enabled); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Object* NS::ProcessInfo::beginActivity(ActivityOptions options, const String* pReason) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(beginActivityWithOptions_reason_), options, pReason); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::endActivity(Object* pActivity) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(endActivity_), pActivity); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performActivity(ActivityOptions options, const String* pReason, void (^block)(void)) +{ + Object::sendMessage(this, _NS_PRIVATE_SEL(performActivityWithOptions_reason_usingBlock_), options, pReason, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void 
NS::ProcessInfo::performActivity(ActivityOptions options, const String* pReason, const std::function& function) +{ + __block std::function blockFunction = function; + + performActivity(options, pReason, ^() { blockFunction(); }); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performExpiringActivity(const String* pReason, void (^block)(bool expired)) +{ + Object::sendMessageSafe(this, _NS_PRIVATE_SEL(performExpiringActivityWithReason_usingBlock_), pReason, block); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE void NS::ProcessInfo::performExpiringActivity(const String* pReason, const std::function& function) +{ + __block std::function blockFunction = function; + + performExpiringActivity(pReason, ^(bool expired) { blockFunction(expired); }); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ProcessInfoThermalState NS::ProcessInfo::thermalState() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(thermalState)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isLowPowerModeEnabled() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isLowPowerModeEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isiOSAppOnMac() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isiOSAppOnMac)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isMacCatalystApp() const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isMacCatalystApp)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::isDeviceCertified(DeviceCertification performanceTier) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(isDeviceCertified_), performanceTier); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::ProcessInfo::hasPerformanceProfile(ProcessPerformanceProfile performanceProfile) const +{ + return Object::sendMessageSafe(this, _NS_PRIVATE_SEL(hasPerformanceProfile_), performanceProfile); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSRange.hpp b/dist/include/metal_cpp/Foundation/NSRange.hpp new file mode 100644 index 0000000..8500271 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSRange.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSRange.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +struct Range +{ + static Range Make(UInteger loc, UInteger len); + + Range(UInteger loc, UInteger len); + + bool Equal(const Range& range) const; + bool LocationInRange(UInteger loc) const; + UInteger Max() const; + + UInteger location; + UInteger length; +} _NS_PACKED; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range::Range(UInteger loc, UInteger len) + : location(loc) + , length(len) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range NS::Range::Make(UInteger loc, UInteger len) +{ + return Range(loc, len); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool 
NS::Range::Equal(const Range& range) const +{ + return (location == range.location) && (length == range.length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::Range::LocationInRange(UInteger loc) const +{ + return (!(loc < location)) && ((loc - location) < length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Range::Max() const +{ + return location + length; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSSet.hpp b/dist/include/metal_cpp/Foundation/NSSet.hpp new file mode 100644 index 0000000..382b671 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSSet.hpp @@ -0,0 +1,87 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSSet.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSObject.hpp" +#include "NSEnumerator.hpp" + +/*****Immutable Set*******/ + +namespace NS +{ + class Set : public NS::Copying + { + public: + UInteger count() const; + Enumerator* objectEnumerator() const; + + static Set* alloc(); + + Set* init(); + Set* init(const Object* const* pObjects, UInteger count); + Set* init(const class Coder* pCoder); + + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::Set::count() const +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(count)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Enumerator* NS::Set::objectEnumerator() const +{ + return NS::Object::sendMessage*>(this, _NS_PRIVATE_SEL(objectEnumerator)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::alloc() +{ + return NS::Object::alloc(_NS_PRIVATE_CLS(NSSet)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::init() +{ + return NS::Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* 
NS::Set::init(const Object* const* pObjects, NS::UInteger count) +{ + return NS::Object::sendMessage(this, _NS_PRIVATE_SEL(initWithObjects_count_), pObjects, count); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Set* NS::Set::init(const class Coder* pCoder) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCoder_), pCoder); +} diff --git a/dist/include/metal_cpp/Foundation/NSSharedPtr.hpp b/dist/include/metal_cpp/Foundation/NSSharedPtr.hpp new file mode 100644 index 0000000..f1cf68e --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSSharedPtr.hpp @@ -0,0 +1,310 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSSharedPtr.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include +#include "NSDefines.hpp" + +namespace NS +{ +template +class SharedPtr +{ +public: + /** + * Create a new null pointer. + */ + SharedPtr(); + + /** + * Destroy this SharedPtr, decreasing the reference count. + */ + ~SharedPtr(); + + /** + * Create a new null pointer. 
+ */ + SharedPtr(std::nullptr_t) noexcept; + + /** + * SharedPtr copy constructor. + */ + SharedPtr(const SharedPtr<_Class>& other) noexcept; + + /** + * Construction from another pointee type. + */ + template + SharedPtr(const SharedPtr<_OtherClass>& other, typename std::enable_if_t> * = nullptr) noexcept; + + /** + * SharedPtr move constructor. + */ + SharedPtr(SharedPtr<_Class>&& other) noexcept; + + /** + * Move from another pointee type. + */ + template + SharedPtr(SharedPtr<_OtherClass>&& other, typename std::enable_if_t> * = nullptr) noexcept; + + /** + * Copy assignment operator. + * Copying increases reference count. Only releases previous pointee if objects are different. + */ + SharedPtr& operator=(const SharedPtr<_Class>& other); + + /** + * Copy-assignment from different pointee. + * Copying increases reference count. Only releases previous pointee if objects are different. + */ + template + typename std::enable_if_t, SharedPtr &> + operator=(const SharedPtr<_OtherClass>& other); + + /** + * Move assignment operator. + * Move without affecting reference counts, unless pointees are equal. Moved-from object is reset to nullptr. + */ + SharedPtr& operator=(SharedPtr<_Class>&& other); + + /** + * Move-asignment from different pointee. + * Move without affecting reference counts, unless pointees are equal. Moved-from object is reset to nullptr. + */ + template + typename std::enable_if_t, SharedPtr &> + operator=(SharedPtr<_OtherClass>&& other); + + /** + * Access raw pointee. + * @warning Avoid wrapping the returned value again, as it may lead double frees unless this object becomes detached. + */ + _Class* get() const; + + /** + * Call operations directly on the pointee. + */ + _Class* operator->() const; + + /** + * Implicit cast to bool. + */ + explicit operator bool() const; + + /** + * Reset this SharedPtr to null, decreasing the reference count. 
+ */ + void reset(); + + /** + * Detach the SharedPtr from the pointee, without decreasing the reference count. + */ + void detach(); + + template + friend SharedPtr<_OtherClass> RetainPtr(_OtherClass* ptr); + + template + friend SharedPtr<_OtherClass> TransferPtr(_OtherClass* ptr); + +private: + _Class* m_pObject; +}; + +/** + * Create a SharedPtr by retaining an existing raw pointer. + * Increases the reference count of the passed-in object. + * If the passed-in object was in an AutoreleasePool, it will be removed from it. + */ +template +_NS_INLINE NS::SharedPtr<_Class> RetainPtr(_Class* pObject) +{ + NS::SharedPtr<_Class> ret; + ret.m_pObject = pObject->retain(); + return ret; +} + +/* + * Create a SharedPtr by transfering the ownership of an existing raw pointer to SharedPtr. + * Does not increase the reference count of the passed-in pointer, it is assumed to be >= 1. + * This method does not remove objects from an AutoreleasePool. +*/ +template +_NS_INLINE NS::SharedPtr<_Class> TransferPtr(_Class* pObject) +{ + NS::SharedPtr<_Class> ret; + ret.m_pObject = pObject; + return ret; +} + +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr() + : m_pObject(nullptr) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::~SharedPtr<_Class>() __attribute__((no_sanitize("undefined"))) +{ + m_pObject->release(); +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(std::nullptr_t) noexcept + : m_pObject(nullptr) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(const SharedPtr<_Class>& other) noexcept + : m_pObject(other.m_pObject->retain()) +{ +} + +template +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(const SharedPtr<_OtherClass>& other, typename std::enable_if_t> *) noexcept + : m_pObject(reinterpret_cast<_Class*>(other.get()->retain())) +{ +} + +template +_NS_INLINE NS::SharedPtr<_Class>::SharedPtr(SharedPtr<_Class>&& other) noexcept + : m_pObject(other.m_pObject) +{ + other.m_pObject = nullptr; +} + +template +template +_NS_INLINE 
NS::SharedPtr<_Class>::SharedPtr(SharedPtr<_OtherClass>&& other, typename std::enable_if_t> *) noexcept + : m_pObject(reinterpret_cast<_Class*>(other.get())) +{ + other.detach(); +} + +template +_NS_INLINE _Class* NS::SharedPtr<_Class>::get() const +{ + return m_pObject; +} + +template +_NS_INLINE _Class* NS::SharedPtr<_Class>::operator->() const +{ + return m_pObject; +} + +template +_NS_INLINE NS::SharedPtr<_Class>::operator bool() const +{ + return nullptr != m_pObject; +} + +template +_NS_INLINE void NS::SharedPtr<_Class>::reset() __attribute__((no_sanitize("undefined"))) +{ + m_pObject->release(); + m_pObject = nullptr; +} + +template +_NS_INLINE void NS::SharedPtr<_Class>::detach() +{ + m_pObject = nullptr; +} + +template +_NS_INLINE NS::SharedPtr<_Class>& NS::SharedPtr<_Class>::operator=(const SharedPtr<_Class>& other) __attribute__((no_sanitize("undefined"))) +{ + _Class* pOldObject = m_pObject; + + m_pObject = other.m_pObject->retain(); + + pOldObject->release(); + + return *this; +} + +template +template +typename std::enable_if_t, NS::SharedPtr<_Class> &> +_NS_INLINE NS::SharedPtr<_Class>::operator=(const SharedPtr<_OtherClass>& other) __attribute__((no_sanitize("undefined"))) +{ + _Class* pOldObject = m_pObject; + + m_pObject = reinterpret_cast<_Class*>(other.get()->retain()); + + pOldObject->release(); + + return *this; +} + +template +_NS_INLINE NS::SharedPtr<_Class>& NS::SharedPtr<_Class>::operator=(SharedPtr<_Class>&& other) __attribute__((no_sanitize("undefined"))) +{ + if (m_pObject != other.m_pObject) + { + m_pObject->release(); + m_pObject = other.m_pObject; + } + else + { + m_pObject = other.m_pObject; + other.m_pObject->release(); + } + other.m_pObject = nullptr; + return *this; +} + +template +template +typename std::enable_if_t, NS::SharedPtr<_Class> &> +_NS_INLINE NS::SharedPtr<_Class>::operator=(SharedPtr<_OtherClass>&& other) __attribute__((no_sanitize("undefined"))) +{ + if (m_pObject != other.get()) + { + m_pObject->release(); + 
m_pObject = reinterpret_cast<_Class*>(other.get()); + other.detach(); + } + else + { + m_pObject = other.get(); + other.reset(); + } + return *this; +} + +template +_NS_INLINE bool operator==(const NS::SharedPtr<_ClassLhs>& lhs, const NS::SharedPtr<_ClassRhs>& rhs) +{ + return lhs.get() == rhs.get(); +} + +template +_NS_INLINE bool operator!=(const NS::SharedPtr<_ClassLhs>& lhs, const NS::SharedPtr<_ClassRhs>& rhs) +{ + return lhs.get() != rhs.get(); +} diff --git a/dist/include/metal_cpp/Foundation/NSString.hpp b/dist/include/metal_cpp/Foundation/NSString.hpp new file mode 100644 index 0000000..c48e068 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSString.hpp @@ -0,0 +1,255 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSString.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObjCRuntime.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSRange.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +_NS_ENUM(NS::UInteger, StringEncoding) { + ASCIIStringEncoding = 1, + NEXTSTEPStringEncoding = 2, + JapaneseEUCStringEncoding = 3, + UTF8StringEncoding = 4, + ISOLatin1StringEncoding = 5, + SymbolStringEncoding = 6, + NonLossyASCIIStringEncoding = 7, + ShiftJISStringEncoding = 8, + ISOLatin2StringEncoding = 9, + UnicodeStringEncoding = 10, + WindowsCP1251StringEncoding = 11, + WindowsCP1252StringEncoding = 12, + WindowsCP1253StringEncoding = 13, + WindowsCP1254StringEncoding = 14, + WindowsCP1250StringEncoding = 15, + ISO2022JPStringEncoding = 21, + MacOSRomanStringEncoding = 30, + + UTF16StringEncoding = UnicodeStringEncoding, + + UTF16BigEndianStringEncoding = 0x90000100, + UTF16LittleEndianStringEncoding = 0x94000100, + + UTF32StringEncoding = 0x8c000100, + UTF32BigEndianStringEncoding = 0x98000100, + UTF32LittleEndianStringEncoding = 0x9c000100 +}; + +_NS_OPTIONS(NS::UInteger, StringCompareOptions) { + CaseInsensitiveSearch = 1, + LiteralSearch = 2, + BackwardsSearch = 4, + AnchoredSearch = 8, + NumericSearch = 64, + DiacriticInsensitiveSearch = 128, + WidthInsensitiveSearch = 256, + ForcedOrderingSearch = 512, + RegularExpressionSearch = 1024 +}; + +using unichar = unsigned short; + +class String : public Copying +{ +public: + static String* string(); + static String* string(const String* 
pString); + static String* string(const char* pString, StringEncoding encoding); + + static String* alloc(); + String* init(); + String* init(const String* pString); + String* init(const char* pString, StringEncoding encoding); + String* init(void* pBytes, UInteger len, StringEncoding encoding, bool freeBuffer); + + unichar character(UInteger index) const; + UInteger length() const; + + const char* cString(StringEncoding encoding) const; + const char* utf8String() const; + UInteger maximumLengthOfBytes(StringEncoding encoding) const; + UInteger lengthOfBytes(StringEncoding encoding) const; + + bool isEqualToString(const String* pString) const; + Range rangeOfString(const String* pString, StringCompareOptions options) const; + + const char* fileSystemRepresentation() const; + + String* stringByAppendingString(const String* pString) const; + ComparisonResult caseInsensitiveCompare(const String* pString) const; +}; + +/// Create an NS::String* from a string literal. +#define MTLSTR(literal) (NS::String*)__builtin___CFStringMakeConstantString("" literal "") + +template +[[deprecated("please use MTLSTR(str)")]] constexpr const String* MakeConstantString(const char (&str)[_StringLen]) +{ + return reinterpret_cast(__CFStringMakeConstantString(str)); +} + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string() +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(string)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string(const String* pString) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(stringWithString_), pString); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::string(const char* pString, StringEncoding encoding) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSString), _NS_PRIVATE_SEL(stringWithCString_encoding_), pString, encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::alloc() +{ + return Object::alloc(_NS_PRIVATE_CLS(NSString)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(const String* pString) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(const char* pString, StringEncoding encoding) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithCString_encoding_), pString, encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::init(void* pBytes, UInteger len, StringEncoding encoding, bool freeBuffer) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithBytesNoCopy_length_encoding_freeWhenDone_), pBytes, len, encoding, freeBuffer); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::unichar NS::String::character(UInteger index) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(characterAtIndex_), index); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::length() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(length)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::cString(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(cStringUsingEncoding_), encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::utf8String() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(UTF8String)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::maximumLengthOfBytes(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(maximumLengthOfBytesUsingEncoding_), encoding); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::UInteger NS::String::lengthOfBytes(StringEncoding encoding) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(lengthOfBytesUsingEncoding_), encoding); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE bool NS::String::isEqualToString(const NS::String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(isEqualToString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::Range NS::String::rangeOfString(const NS::String* pString, NS::StringCompareOptions options) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(rangeOfString_options_), pString, options); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::String::fileSystemRepresentation() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(fileSystemRepresentation)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::String* NS::String::stringByAppendingString(const String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(stringByAppendingString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::ComparisonResult NS::String::caseInsensitiveCompare(const String* pString) const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(caseInsensitiveCompare_), pString); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSTypes.hpp b/dist/include/metal_cpp/Foundation/NSTypes.hpp new file mode 100644 index 0000000..e6b723e --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSTypes.hpp @@ -0,0 +1,51 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSTypes.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" + +#include +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +using TimeInterval = double; + +using Integer = std::intptr_t; +using UInteger = std::uintptr_t; + +const Integer IntegerMax = INTPTR_MAX; +const Integer IntegerMin = INTPTR_MIN; +const UInteger UIntegerMax = UINTPTR_MAX; + +struct OperatingSystemVersion +{ + Integer majorVersion; + Integer minorVersion; + Integer patchVersion; +} _NS_PACKED; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Foundation/NSURL.hpp b/dist/include/metal_cpp/Foundation/NSURL.hpp new file mode 100644 index 0000000..d90e5d7 --- /dev/null +++ b/dist/include/metal_cpp/Foundation/NSURL.hpp @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Foundation/NSURL.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "NSDefines.hpp" +#include "NSObject.hpp" +#include "NSPrivate.hpp" +#include "NSTypes.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace NS +{ +class URL : public Copying +{ +public: + static URL* fileURLWithPath(const class String* pPath); + + static URL* alloc(); + URL* init(); + URL* init(const class String* pString); + URL* initFileURLWithPath(const class String* pPath); + + const char* fileSystemRepresentation() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::fileURLWithPath(const String* pPath) +{ + return Object::sendMessage(_NS_PRIVATE_CLS(NSURL), _NS_PRIVATE_SEL(fileURLWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::alloc() +{ + return Object::alloc(_NS_PRIVATE_CLS(NSURL)); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::init() +{ + return Object::init(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::init(const String* pString) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initWithString_), pString); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE NS::URL* NS::URL::initFileURLWithPath(const String* pPath) +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(initFileURLWithPath_), pPath); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_NS_INLINE const char* NS::URL::fileSystemRepresentation() const +{ + return Object::sendMessage(this, _NS_PRIVATE_SEL(fileSystemRepresentation)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/LICENSE.txt b/dist/include/metal_cpp/LICENSE.txt new file mode 100644 index 0000000..d07f885 --- /dev/null +++ b/dist/include/metal_cpp/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright © 2024 Apple Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/dist/include/metal_cpp/Metal/MTL4AccelerationStructure.hpp b/dist/include/metal_cpp/Metal/MTL4AccelerationStructure.hpp new file mode 100644 index 0000000..1154015 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4AccelerationStructure.hpp @@ -0,0 +1,1395 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4AccelerationStructure.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLStageInputOutputDescriptor.hpp" + +namespace MTL4 +{ +class AccelerationStructureBoundingBoxGeometryDescriptor; +class AccelerationStructureCurveGeometryDescriptor; +class AccelerationStructureDescriptor; +class AccelerationStructureGeometryDescriptor; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor; +class AccelerationStructureMotionCurveGeometryDescriptor; +class AccelerationStructureMotionTriangleGeometryDescriptor; +class AccelerationStructureTriangleGeometryDescriptor; +class IndirectInstanceAccelerationStructureDescriptor; +class InstanceAccelerationStructureDescriptor; +class PrimitiveAccelerationStructureDescriptor; + +class AccelerationStructureDescriptor : public NS::Copying +{ +public: + static AccelerationStructureDescriptor* alloc(); + + AccelerationStructureDescriptor* init(); +}; +class AccelerationStructureGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureGeometryDescriptor* alloc(); + + bool allowDuplicateIntersectionFunctionInvocation() const; + + AccelerationStructureGeometryDescriptor* init(); + + NS::UInteger intersectionFunctionTableOffset() const; 
+ + NS::String* label() const; + + bool opaque() const; + + BufferRange primitiveDataBuffer() const; + + NS::UInteger primitiveDataElementSize() const; + + NS::UInteger primitiveDataStride() const; + + void setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation); + + void setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset); + + void setLabel(const NS::String* label); + + void setOpaque(bool opaque); + + void setPrimitiveDataBuffer(const MTL4::BufferRange primitiveDataBuffer); + + void setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize); + + void setPrimitiveDataStride(NS::UInteger primitiveDataStride); +}; +class PrimitiveAccelerationStructureDescriptor : public NS::Copying +{ +public: + static PrimitiveAccelerationStructureDescriptor* alloc(); + + NS::Array* geometryDescriptors() const; + + PrimitiveAccelerationStructureDescriptor* init(); + + MTL::MotionBorderMode motionEndBorderMode() const; + + float motionEndTime() const; + + NS::UInteger motionKeyframeCount() const; + + MTL::MotionBorderMode motionStartBorderMode() const; + + float motionStartTime() const; + + void setGeometryDescriptors(const NS::Array* geometryDescriptors); + + void setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode); + + void setMotionEndTime(float motionEndTime); + + void setMotionKeyframeCount(NS::UInteger motionKeyframeCount); + + void setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode); + + void setMotionStartTime(float motionStartTime); +}; +class AccelerationStructureTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureTriangleGeometryDescriptor* alloc(); + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void 
setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffer(const MTL4::BufferRange vertexBuffer); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + BufferRange transformationMatrixBuffer() const; + + MTL::MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + BufferRange vertexBuffer() const; + + MTL::AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureBoundingBoxGeometryDescriptor* alloc(); + + BufferRange boundingBoxBuffer() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + AccelerationStructureBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffer(const MTL4::BufferRange boundingBoxBuffer); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureMotionTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionTriangleGeometryDescriptor* alloc(); + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureMotionTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffers(const MTL4::BufferRange vertexBuffers); + + void 
setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + BufferRange transformationMatrixBuffer() const; + + MTL::MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + BufferRange vertexBuffers() const; + + MTL::AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* alloc(); + + BufferRange boundingBoxBuffers() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + AccelerationStructureMotionBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffers(const MTL4::BufferRange boundingBoxBuffers); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureCurveGeometryDescriptor* alloc(); + + BufferRange controlPointBuffer() const; + + NS::UInteger controlPointCount() const; + + MTL::AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + MTL::CurveBasis curveBasis() const; + + MTL::CurveEndCaps curveEndCaps() const; + + MTL::CurveType curveType() const; + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureCurveGeometryDescriptor* init(); + + BufferRange radiusBuffer() const; + + MTL::AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffer(const MTL4::BufferRange controlPointBuffer); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void 
setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffer(const MTL4::BufferRange radiusBuffer); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class AccelerationStructureMotionCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionCurveGeometryDescriptor* alloc(); + + BufferRange controlPointBuffers() const; + + NS::UInteger controlPointCount() const; + + MTL::AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + MTL::CurveBasis curveBasis() const; + + MTL::CurveEndCaps curveEndCaps() const; + + MTL::CurveType curveType() const; + + BufferRange indexBuffer() const; + + MTL::IndexType indexType() const; + + AccelerationStructureMotionCurveGeometryDescriptor* init(); + + BufferRange radiusBuffers() const; + + MTL::AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffers(const MTL4::BufferRange controlPointBuffers); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL4::BufferRange indexBuffer); + + void setIndexType(MTL::IndexType 
indexType); + + void setRadiusBuffers(const MTL4::BufferRange radiusBuffers); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class InstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static InstanceAccelerationStructureDescriptor* alloc(); + + InstanceAccelerationStructureDescriptor* init(); + + NS::UInteger instanceCount() const; + + BufferRange instanceDescriptorBuffer() const; + + NS::UInteger instanceDescriptorStride() const; + + MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MTL::MatrixLayout instanceTransformationMatrixLayout() const; + + BufferRange motionTransformBuffer() const; + + NS::UInteger motionTransformCount() const; + + NS::UInteger motionTransformStride() const; + + MTL::TransformType motionTransformType() const; + + void setInstanceCount(NS::UInteger instanceCount); + + void setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer); + + void setMotionTransformCount(NS::UInteger motionTransformCount); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class IndirectInstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static IndirectInstanceAccelerationStructureDescriptor* alloc(); + + IndirectInstanceAccelerationStructureDescriptor* init(); + + BufferRange instanceCountBuffer() const; + + 
BufferRange instanceDescriptorBuffer() const; + + NS::UInteger instanceDescriptorStride() const; + + MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MTL::MatrixLayout instanceTransformationMatrixLayout() const; + + NS::UInteger maxInstanceCount() const; + + NS::UInteger maxMotionTransformCount() const; + + BufferRange motionTransformBuffer() const; + + BufferRange motionTransformCountBuffer() const; + + NS::UInteger motionTransformStride() const; + + MTL::TransformType motionTransformType() const; + + void setInstanceCountBuffer(const MTL4::BufferRange instanceCountBuffer); + + void setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMaxInstanceCount(NS::UInteger maxInstanceCount); + + void setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount); + + void setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer); + + void setMotionTransformCountBuffer(const MTL4::BufferRange motionTransformCountBuffer); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; + +} +_MTL_INLINE MTL4::AccelerationStructureDescriptor* MTL4::AccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL4::AccelerationStructureDescriptor* MTL4::AccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::AccelerationStructureGeometryDescriptor* MTL4::AccelerationStructureGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureGeometryDescriptor)); +} + 
+_MTL_INLINE bool MTL4::AccelerationStructureGeometryDescriptor::allowDuplicateIntersectionFunctionInvocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowDuplicateIntersectionFunctionInvocation)); +} + +_MTL_INLINE MTL4::AccelerationStructureGeometryDescriptor* MTL4::AccelerationStructureGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::intersectionFunctionTableOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intersectionFunctionTableOffset)); +} + +_MTL_INLINE NS::String* MTL4::AccelerationStructureGeometryDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL4::AccelerationStructureGeometryDescriptor::opaque() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(opaque)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureGeometryDescriptor::primitiveDataBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::primitiveDataElementSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataElementSize)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureGeometryDescriptor::primitiveDataStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataStride)); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowDuplicateIntersectionFunctionInvocation_), allowDuplicateIntersectionFunctionInvocation); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setIntersectionFunctionTableOffset_), intersectionFunctionTableOffset); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setOpaque(bool opaque) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaque_), opaque); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataBuffer(const MTL4::BufferRange primitiveDataBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBuffer_), primitiveDataBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataElementSize_), primitiveDataElementSize); +} + +_MTL_INLINE void MTL4::AccelerationStructureGeometryDescriptor::setPrimitiveDataStride(NS::UInteger primitiveDataStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataStride_), primitiveDataStride); +} + +_MTL_INLINE MTL4::PrimitiveAccelerationStructureDescriptor* MTL4::PrimitiveAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PrimitiveAccelerationStructureDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::PrimitiveAccelerationStructureDescriptor::geometryDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(geometryDescriptors)); +} + +_MTL_INLINE MTL4::PrimitiveAccelerationStructureDescriptor* MTL4::PrimitiveAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::MotionBorderMode MTL4::PrimitiveAccelerationStructureDescriptor::motionEndBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndBorderMode)); +} + +_MTL_INLINE float MTL4::PrimitiveAccelerationStructureDescriptor::motionEndTime() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndTime)); +} + +_MTL_INLINE NS::UInteger MTL4::PrimitiveAccelerationStructureDescriptor::motionKeyframeCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionKeyframeCount)); +} + +_MTL_INLINE MTL::MotionBorderMode MTL4::PrimitiveAccelerationStructureDescriptor::motionStartBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartBorderMode)); +} + +_MTL_INLINE float MTL4::PrimitiveAccelerationStructureDescriptor::motionStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartTime)); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setGeometryDescriptors(const NS::Array* geometryDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGeometryDescriptors_), geometryDescriptors); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndBorderMode_), motionEndBorderMode); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionEndTime(float motionEndTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndTime_), motionEndTime); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionKeyframeCount(NS::UInteger motionKeyframeCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionKeyframeCount_), motionKeyframeCount); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartBorderMode_), motionStartBorderMode); +} + +_MTL_INLINE void MTL4::PrimitiveAccelerationStructureDescriptor::setMotionStartTime(float motionStartTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartTime_), motionStartTime); +} + +_MTL_INLINE 
MTL4::AccelerationStructureTriangleGeometryDescriptor* MTL4::AccelerationStructureTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureTriangleGeometryDescriptor* MTL4::AccelerationStructureTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexBuffer(const 
MTL4::BufferRange vertexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_), vertexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffer)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange 
MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBuffer(const MTL4::BufferRange boundingBoxBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffer_), boundingBoxBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionTriangleGeometryDescriptor* MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionTriangleGeometryDescriptor* MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL4::BufferRange transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexBuffers(const MTL4::BufferRange vertexBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_), vertexBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffers)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger 
MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxBuffers(const MTL4::BufferRange boundingBoxBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffers_), boundingBoxBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL4::AccelerationStructureCurveGeometryDescriptor* MTL4::AccelerationStructureCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::controlPointStride() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL4::AccelerationStructureCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL4::AccelerationStructureCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL4::AccelerationStructureCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureCurveGeometryDescriptor* MTL4::AccelerationStructureCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureCurveGeometryDescriptor::radiusBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffer)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointBuffer(const MTL4::BufferRange controlPointBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffer_), controlPointBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void 
MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusBuffer(const MTL4::BufferRange radiusBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffer_), radiusBuffer); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionCurveGeometryDescriptor* MTL4::AccelerationStructureMotionCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4AccelerationStructureMotionCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffers)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::controlPointStride() const 
+{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL4::AccelerationStructureMotionCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE MTL::IndexType MTL4::AccelerationStructureMotionCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL4::AccelerationStructureMotionCurveGeometryDescriptor* MTL4::AccelerationStructureMotionCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL4::AccelerationStructureMotionCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger 
MTL4::AccelerationStructureMotionCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointBuffers(const MTL4::BufferRange controlPointBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffers_), controlPointBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBuffer(const MTL4::BufferRange indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void 
MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusBuffers(const MTL4::BufferRange radiusBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffers_), radiusBuffers); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL4::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL4::InstanceAccelerationStructureDescriptor* MTL4::InstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4InstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL4::InstanceAccelerationStructureDescriptor* MTL4::InstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::instanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL4::InstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::InstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::InstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::motionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCount)); +} + +_MTL_INLINE NS::UInteger MTL4::InstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL4::InstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceCount(NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCount_), instanceCount); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void 
MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformCount(NS::UInteger motionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCount_), motionTransformCount); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL4::InstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL4::IndirectInstanceAccelerationStructureDescriptor* MTL4::IndirectInstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4IndirectInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE 
MTL4::IndirectInstanceAccelerationStructureDescriptor* MTL4::IndirectInstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBuffer)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL4::IndirectInstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::maxInstanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxInstanceCount)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::maxMotionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMotionTransformCount)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE MTL4::BufferRange MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBuffer() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(motionTransformCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL4::IndirectInstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBuffer(const MTL4::BufferRange instanceCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBuffer_), instanceCountBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL4::BufferRange instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMaxInstanceCount(NS::UInteger maxInstanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxInstanceCount_), maxInstanceCount); +} + 
+_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMotionTransformCount_), maxMotionTransformCount); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL4::BufferRange motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBuffer(const MTL4::BufferRange motionTransformCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBuffer_), motionTransformCountBuffer); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL4::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} diff --git a/dist/include/metal_cpp/Metal/MTL4Archive.hpp b/dist/include/metal_cpp/Metal/MTL4Archive.hpp new file mode 100644 index 0000000..c83ef63 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4Archive.hpp @@ -0,0 +1,93 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Archive.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class ComputePipelineState; +class RenderPipelineState; +} + +namespace MTL4 +{ +class BinaryFunction; +class BinaryFunctionDescriptor; +class ComputePipelineDescriptor; +class PipelineDescriptor; +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; + +class Archive : public NS::Referencing +{ +public: + NS::String* label() const; + + BinaryFunction* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, NS::Error** error); + + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, NS::Error** error); + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error); + + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, NS::Error** error); + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE NS::String* MTL4::Archive::label() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::BinaryFunction* MTL4::Archive::newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Archive::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Archive::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_error_), descriptor, dynamicLinkingDescriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Archive::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Archive::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_error_), descriptor, dynamicLinkingDescriptor, error); +} + +_MTL_INLINE void MTL4::Archive::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTL4ArgumentTable.hpp b/dist/include/metal_cpp/Metal/MTL4ArgumentTable.hpp new file mode 100644 index 0000000..7788ed9 --- /dev/null +++ 
b/dist/include/metal_cpp/Metal/MTL4ArgumentTable.hpp @@ -0,0 +1,187 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ArgumentTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +} + +namespace MTL4 +{ +class ArgumentTableDescriptor : public NS::Copying +{ +public: + static ArgumentTableDescriptor* alloc(); + + ArgumentTableDescriptor* init(); + bool initializeBindings() const; + + NS::String* label() const; + + NS::UInteger maxBufferBindCount() const; + + NS::UInteger maxSamplerStateBindCount() const; + + NS::UInteger maxTextureBindCount() const; + + void setInitializeBindings(bool initializeBindings); + + void setLabel(const NS::String* label); + + void setMaxBufferBindCount(NS::UInteger maxBufferBindCount); + + void setMaxSamplerStateBindCount(NS::UInteger maxSamplerStateBindCount); + + void setMaxTextureBindCount(NS::UInteger maxTextureBindCount); + + void setSupportAttributeStrides(bool 
supportAttributeStrides); + bool supportAttributeStrides() const; +}; +class ArgumentTable : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::String* label() const; + + void setAddress(MTL::GPUAddress gpuAddress, NS::UInteger bindingIndex); + void setAddress(MTL::GPUAddress gpuAddress, NS::UInteger stride, NS::UInteger bindingIndex); + + void setResource(MTL::ResourceID resourceID, NS::UInteger bindingIndex); + + void setSamplerState(MTL::ResourceID resourceID, NS::UInteger bindingIndex); + + void setTexture(MTL::ResourceID resourceID, NS::UInteger bindingIndex); +}; + +} +_MTL_INLINE MTL4::ArgumentTableDescriptor* MTL4::ArgumentTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4ArgumentTableDescriptor)); +} + +_MTL_INLINE MTL4::ArgumentTableDescriptor* MTL4::ArgumentTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL4::ArgumentTableDescriptor::initializeBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initializeBindings)); +} + +_MTL_INLINE NS::String* MTL4::ArgumentTableDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxSamplerStateBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxSamplerStateBindCount)); +} + +_MTL_INLINE NS::UInteger MTL4::ArgumentTableDescriptor::maxTextureBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTextureBindCount)); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setInitializeBindings(bool initializeBindings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInitializeBindings_), initializeBindings); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setLabel(const NS::String* label) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxBufferBindCount(NS::UInteger maxBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxBufferBindCount_), maxBufferBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxSamplerStateBindCount(NS::UInteger maxSamplerStateBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxSamplerStateBindCount_), maxSamplerStateBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setMaxTextureBindCount(NS::UInteger maxTextureBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTextureBindCount_), maxTextureBindCount); +} + +_MTL_INLINE void MTL4::ArgumentTableDescriptor::setSupportAttributeStrides(bool supportAttributeStrides) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAttributeStrides_), supportAttributeStrides); +} + +_MTL_INLINE bool MTL4::ArgumentTableDescriptor::supportAttributeStrides() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAttributeStrides)); +} + +_MTL_INLINE MTL::Device* MTL4::ArgumentTable::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::ArgumentTable::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::ArgumentTable::setAddress(MTL::GPUAddress gpuAddress, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAddress_atIndex_), gpuAddress, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setAddress(MTL::GPUAddress gpuAddress, NS::UInteger stride, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAddress_attributeStride_atIndex_), gpuAddress, stride, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setResource(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResource_atBufferIndex_), 
resourceID, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setSamplerState(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), resourceID, bindingIndex); +} + +_MTL_INLINE void MTL4::ArgumentTable::setTexture(MTL::ResourceID resourceID, NS::UInteger bindingIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), resourceID, bindingIndex); +} diff --git a/dist/include/metal_cpp/Metal/MTL4BinaryFunction.hpp b/dist/include/metal_cpp/Metal/MTL4BinaryFunction.hpp new file mode 100644 index 0000000..30d90a6 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4BinaryFunction.hpp @@ -0,0 +1,50 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4BinaryFunction.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLLibrary.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ + +class BinaryFunction : public NS::Referencing +{ +public: + MTL::FunctionType functionType() const; + + NS::String* name() const; +}; + +} + +_MTL_INLINE MTL::FunctionType MTL4::BinaryFunction::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE NS::String* MTL4::BinaryFunction::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4BinaryFunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4BinaryFunctionDescriptor.hpp new file mode 100644 index 0000000..ce173ce --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4BinaryFunctionDescriptor.hpp @@ -0,0 +1,97 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4BinaryFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class BinaryFunctionDescriptor; +class FunctionDescriptor; + +_MTL_OPTIONS(NS::UInteger, BinaryFunctionOptions) { + BinaryFunctionOptionNone = 0, + BinaryFunctionOptionPipelineIndependent = 1 << 1, +}; + +class BinaryFunctionDescriptor : public NS::Copying +{ +public: + static BinaryFunctionDescriptor* alloc(); + + FunctionDescriptor* functionDescriptor() const; + + BinaryFunctionDescriptor* init(); + + NS::String* name() const; + + BinaryFunctionOptions options() const; + + void setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor); + + void setName(const NS::String* name); + + void setOptions(MTL4::BinaryFunctionOptions options); +}; + +} +_MTL_INLINE MTL4::BinaryFunctionDescriptor* MTL4::BinaryFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4BinaryFunctionDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::BinaryFunctionDescriptor::functionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL4::BinaryFunctionDescriptor* MTL4::BinaryFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::BinaryFunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL4::BinaryFunctionOptions MTL4::BinaryFunctionDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptor_), functionDescriptor); +} + 
+_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL4::BinaryFunctionDescriptor::setOptions(MTL4::BinaryFunctionOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CommandAllocator.hpp b/dist/include/metal_cpp/Metal/MTL4CommandAllocator.hpp new file mode 100644 index 0000000..a36b050 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CommandAllocator.hpp @@ -0,0 +1,100 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandAllocator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Device; +} + +namespace MTL4 +{ + +class CommandAllocatorDescriptor : public NS::Copying +{ +public: + static CommandAllocatorDescriptor* alloc(); + + CommandAllocatorDescriptor* init(); + + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +class CommandAllocator : public NS::Referencing +{ +public: + uint64_t allocatedSize(); + + MTL::Device* device() const; + + NS::String* label() const; + + void reset(); +}; + +} + +_MTL_INLINE MTL4::CommandAllocatorDescriptor* MTL4::CommandAllocatorDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandAllocatorDescriptor)); +} + +_MTL_INLINE MTL4::CommandAllocatorDescriptor* MTL4::CommandAllocatorDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CommandAllocatorDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandAllocatorDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE uint64_t MTL4::CommandAllocator::allocatedSize() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE MTL::Device* MTL4::CommandAllocator::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::CommandAllocator::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandAllocator::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CommandBuffer.hpp 
b/dist/include/metal_cpp/Metal/MTL4CommandBuffer.hpp new file mode 100644 index 0000000..a69cc94 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CommandBuffer.hpp @@ -0,0 +1,193 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4RenderCommandEncoder.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class CommandAllocator; +class CommandBufferOptions; +class ComputeCommandEncoder; +class CounterHeap; +class MachineLearningCommandEncoder; +class RenderCommandEncoder; +class RenderPassDescriptor; +} + +namespace MTL +{ +class Device; +class Fence; +class LogState; +class ResidencySet; +} + +namespace MTL4 +{ +class CommandBufferOptions : public NS::Copying +{ +public: + static CommandBufferOptions* alloc(); + + CommandBufferOptions* init(); + + MTL::LogState* logState() const; + void setLogState(const MTL::LogState* logState); +}; +class CommandBuffer : public NS::Referencing +{ 
+public: + void beginCommandBuffer(const MTL4::CommandAllocator* allocator); + void beginCommandBuffer(const MTL4::CommandAllocator* allocator, const MTL4::CommandBufferOptions* options); + + ComputeCommandEncoder* computeCommandEncoder(); + + MTL::Device* device() const; + + void endCommandBuffer(); + + NS::String* label() const; + + MachineLearningCommandEncoder* machineLearningCommandEncoder(); + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + RenderCommandEncoder* renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor); + RenderCommandEncoder* renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor, MTL4::RenderEncoderOptions options); + + void resolveCounterHeap(const MTL4::CounterHeap* counterHeap, NS::Range range, const MTL4::BufferRange bufferRange, const MTL::Fence* fenceToWait, const MTL::Fence* fenceToUpdate); + + void setLabel(const NS::String* label); + + void useResidencySet(const MTL::ResidencySet* residencySet); + void useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void writeTimestampIntoHeap(const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE MTL4::CommandBufferOptions* MTL4::CommandBufferOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandBufferOptions)); +} + +_MTL_INLINE MTL4::CommandBufferOptions* MTL4::CommandBufferOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL4::CommandBufferOptions::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE void MTL4::CommandBufferOptions::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL4::CommandBuffer::beginCommandBuffer(const MTL4::CommandAllocator* allocator) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(beginCommandBufferWithAllocator_), allocator); +} + +_MTL_INLINE void 
MTL4::CommandBuffer::beginCommandBuffer(const MTL4::CommandAllocator* allocator, const MTL4::CommandBufferOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(beginCommandBufferWithAllocator_options_), allocator, options); +} + +_MTL_INLINE MTL4::ComputeCommandEncoder* MTL4::CommandBuffer::computeCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoder)); +} + +_MTL_INLINE MTL::Device* MTL4::CommandBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL4::CommandBuffer::endCommandBuffer() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endCommandBuffer)); +} + +_MTL_INLINE NS::String* MTL4::CommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::MachineLearningCommandEncoder* MTL4::CommandBuffer::machineLearningCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(machineLearningCommandEncoder)); +} + +_MTL_INLINE void MTL4::CommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL4::CommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE MTL4::RenderCommandEncoder* MTL4::CommandBuffer::renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::RenderCommandEncoder* MTL4::CommandBuffer::renderCommandEncoder(const MTL4::RenderPassDescriptor* descriptor, MTL4::RenderEncoderOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_options_), descriptor, options); +} + +_MTL_INLINE void MTL4::CommandBuffer::resolveCounterHeap(const MTL4::CounterHeap* counterHeap, NS::Range range, const MTL4::BufferRange bufferRange, const MTL::Fence* fenceToWait, 
const MTL::Fence* fenceToUpdate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterHeap_withRange_intoBuffer_waitFence_updateFence_), counterHeap, range, bufferRange, fenceToWait, fenceToUpdate); +} + +_MTL_INLINE void MTL4::CommandBuffer::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandBuffer::useResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandBuffer::useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL4::CommandBuffer::writeTimestampIntoHeap(const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampIntoHeap_atIndex_), counterHeap, index); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTL4CommandEncoder.hpp new file mode 100644 index 0000000..2336021 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CommandEncoder.hpp @@ -0,0 +1,134 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class CommandBuffer; +} + +namespace MTL +{ +class Fence; +} + +namespace MTL4 +{ +_MTL_OPTIONS(NS::UInteger, VisibilityOptions) { + VisibilityOptionNone = 0, + VisibilityOptionDevice = 1, + VisibilityOptionResourceAlias = 1 << 1, +}; + +class CommandEncoder : public NS::Referencing +{ +public: + void barrierAfterEncoderStages(MTL::Stages afterEncoderStages, MTL::Stages beforeEncoderStages, MTL4::VisibilityOptions visibilityOptions); + + void barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages, MTL4::VisibilityOptions visibilityOptions); + + void barrierAfterStages(MTL::Stages afterStages, MTL::Stages beforeQueueStages, MTL4::VisibilityOptions visibilityOptions); + + CommandBuffer* commandBuffer() const; + + void endEncoding(); + + void insertDebugSignpost(const NS::String* string); + + NS::String* label() const; + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); + + void updateFence(const MTL::Fence* fence, MTL::Stages afterEncoderStages); + + void waitForFence(const MTL::Fence* fence, MTL::Stages beforeEncoderStages); +}; + +} +_MTL_INLINE void MTL4::CommandEncoder::barrierAfterEncoderStages(MTL::Stages afterEncoderStages, MTL::Stages beforeEncoderStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterEncoderStages_beforeEncoderStages_visibilityOptions_), afterEncoderStages, beforeEncoderStages, visibilityOptions); +} + +_MTL_INLINE void 
MTL4::CommandEncoder::barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterQueueStages_beforeStages_visibilityOptions_), afterQueueStages, beforeStages, visibilityOptions); +} + +_MTL_INLINE void MTL4::CommandEncoder::barrierAfterStages(MTL::Stages afterStages, MTL::Stages beforeQueueStages, MTL4::VisibilityOptions visibilityOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterStages_beforeQueueStages_visibilityOptions_), afterStages, beforeQueueStages, visibilityOptions); +} + +_MTL_INLINE MTL4::CommandBuffer* MTL4::CommandEncoder::commandBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE void MTL4::CommandEncoder::endEncoding() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endEncoding)); +} + +_MTL_INLINE void MTL4::CommandEncoder::insertDebugSignpost(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugSignpost_), string); +} + +_MTL_INLINE NS::String* MTL4::CommandEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandEncoder::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL4::CommandEncoder::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL4::CommandEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandEncoder::updateFence(const MTL::Fence* fence, MTL::Stages afterEncoderStages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_afterEncoderStages_), fence, afterEncoderStages); +} + +_MTL_INLINE void MTL4::CommandEncoder::waitForFence(const MTL::Fence* fence, MTL::Stages beforeEncoderStages) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_beforeEncoderStages_), fence, beforeEncoderStages); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CommandQueue.hpp b/dist/include/metal_cpp/Metal/MTL4CommandQueue.hpp new file mode 100644 index 0000000..cbd21c7 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CommandQueue.hpp @@ -0,0 +1,283 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommitFeedback.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResourceStateCommandEncoder.hpp" +#include "MTLTypes.hpp" +#include +#include + +namespace MTL +{ +class Buffer; +class Device; +class Drawable; +class Event; +class Heap; +class ResidencySet; +class Texture; +} + +namespace MTL4 +{ +class CommandBuffer; +class CommandQueueDescriptor; +class CommitOptions; +struct CopySparseBufferMappingOperation; +struct CopySparseTextureMappingOperation; +struct UpdateSparseBufferMappingOperation; +struct UpdateSparseTextureMappingOperation; +_MTL_ENUM(NS::Integer, CommandQueueError) { + CommandQueueErrorNone = 0, + CommandQueueErrorTimeout = 1, + CommandQueueErrorNotPermitted = 2, + CommandQueueErrorOutOfMemory = 3, + CommandQueueErrorDeviceRemoved = 4, + CommandQueueErrorAccessRevoked = 5, + CommandQueueErrorInternal = 6, +}; + +struct UpdateSparseTextureMappingOperation +{ + MTL::SparseTextureMappingMode mode; + MTL::Region textureRegion; + NS::UInteger textureLevel; + NS::UInteger textureSlice; + NS::UInteger heapOffset; +} _MTL_PACKED; + +struct CopySparseTextureMappingOperation +{ + MTL::Region sourceRegion; + NS::UInteger sourceLevel; + NS::UInteger sourceSlice; + MTL::Origin destinationOrigin; + NS::UInteger destinationLevel; + NS::UInteger destinationSlice; +} _MTL_PACKED; + +struct UpdateSparseBufferMappingOperation +{ + MTL::SparseTextureMappingMode mode; + NS::Range bufferRange; + NS::UInteger heapOffset; +} _MTL_PACKED; + +struct CopySparseBufferMappingOperation +{ + NS::Range sourceRange; + NS::UInteger destinationOffset; +} _MTL_PACKED; + +class CommitOptions : public NS::Referencing +{ +public: + void addFeedbackHandler(const MTL4::CommitFeedbackHandler block); 
+ void addFeedbackHandler(const MTL4::CommitFeedbackHandlerFunction& function); + + static CommitOptions* alloc(); + + CommitOptions* init(); +}; +class CommandQueueDescriptor : public NS::Copying +{ +public: + static CommandQueueDescriptor* alloc(); + + dispatch_queue_t feedbackQueue() const; + + CommandQueueDescriptor* init(); + + NS::String* label() const; + + void setFeedbackQueue(const dispatch_queue_t feedbackQueue); + + void setLabel(const NS::String* label); +}; +class CommandQueue : public NS::Referencing +{ +public: + void addResidencySet(const MTL::ResidencySet* residencySet); + void addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count); + void commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count, const MTL4::CommitOptions* options); + + void copyBufferMappingsFromBuffer(const MTL::Buffer* sourceBuffer, const MTL::Buffer* destinationBuffer, const MTL4::CopySparseBufferMappingOperation* operations, NS::UInteger count); + + void copyTextureMappingsFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture, const MTL4::CopySparseTextureMappingOperation* operations, NS::UInteger count); + + MTL::Device* device() const; + + NS::String* label() const; + + void removeResidencySet(const MTL::ResidencySet* residencySet); + void removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void signalDrawable(const MTL::Drawable* drawable); + + void signalEvent(const MTL::Event* event, uint64_t value); + + void updateBufferMappings(const MTL::Buffer* buffer, const MTL::Heap* heap, const MTL4::UpdateSparseBufferMappingOperation* operations, NS::UInteger count); + + void updateTextureMappings(const MTL::Texture* texture, const MTL::Heap* heap, const MTL4::UpdateSparseTextureMappingOperation* operations, NS::UInteger count); + + void wait(const MTL::Event* event, 
uint64_t value); + void wait(const MTL::Drawable* drawable); +}; + +} + +_MTL_INLINE void MTL4::CommitOptions::addFeedbackHandler(const MTL4::CommitFeedbackHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addFeedbackHandler_), block); +} + +_MTL_INLINE void MTL4::CommitOptions::addFeedbackHandler(const MTL4::CommitFeedbackHandlerFunction& function) +{ + __block MTL4::CommitFeedbackHandlerFunction blockFunction = function; + addFeedbackHandler(^(MTL4::CommitFeedback* pFeedback) { blockFunction(pFeedback); }); +} + +_MTL_INLINE MTL4::CommitOptions* MTL4::CommitOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommitOptions)); +} + +_MTL_INLINE MTL4::CommitOptions* MTL4::CommitOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::CommandQueueDescriptor* MTL4::CommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CommandQueueDescriptor)); +} + +_MTL_INLINE dispatch_queue_t MTL4::CommandQueueDescriptor::feedbackQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(feedbackQueue)); +} + +_MTL_INLINE MTL4::CommandQueueDescriptor* MTL4::CommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CommandQueueDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandQueueDescriptor::setFeedbackQueue(const dispatch_queue_t feedbackQueue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFeedbackQueue_), feedbackQueue); +} + +_MTL_INLINE void MTL4::CommandQueueDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CommandQueue::addResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandQueue::addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL4::CommandQueue::commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit_count_), commandBuffers, count); +} + +_MTL_INLINE void MTL4::CommandQueue::commit(const MTL4::CommandBuffer* const commandBuffers[], NS::UInteger count, const MTL4::CommitOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit_count_options_), commandBuffers, count, options); +} + +_MTL_INLINE void MTL4::CommandQueue::copyBufferMappingsFromBuffer(const MTL::Buffer* sourceBuffer, const MTL::Buffer* destinationBuffer, const MTL4::CopySparseBufferMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyBufferMappingsFromBuffer_toBuffer_operations_count_), sourceBuffer, destinationBuffer, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::copyTextureMappingsFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture, const MTL4::CopySparseTextureMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyTextureMappingsFromTexture_toTexture_operations_count_), sourceTexture, destinationTexture, operations, count); +} + +_MTL_INLINE MTL::Device* MTL4::CommandQueue::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::CommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL4::CommandQueue::removeResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySet_), residencySet); +} + +_MTL_INLINE void MTL4::CommandQueue::removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySets_count_), 
residencySets, count); +} + +_MTL_INLINE void MTL4::CommandQueue::signalDrawable(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalDrawable_), drawable); +} + +_MTL_INLINE void MTL4::CommandQueue::signalEvent(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalEvent_value_), event, value); +} + +_MTL_INLINE void MTL4::CommandQueue::updateBufferMappings(const MTL::Buffer* buffer, const MTL::Heap* heap, const MTL4::UpdateSparseBufferMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateBufferMappings_heap_operations_count_), buffer, heap, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::updateTextureMappings(const MTL::Texture* texture, const MTL::Heap* heap, const MTL4::UpdateSparseTextureMappingOperation* operations, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMappings_heap_operations_count_), texture, heap, operations, count); +} + +_MTL_INLINE void MTL4::CommandQueue::wait(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL4::CommandQueue::wait(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForDrawable_), drawable); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CommitFeedback.hpp b/dist/include/metal_cpp/Metal/MTL4CommitFeedback.hpp new file mode 100644 index 0000000..6b8181f --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CommitFeedback.hpp @@ -0,0 +1,62 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CommitFeedback.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include + +namespace MTL4 +{ +class CommitFeedback; + +using CommitFeedbackHandler = void (^)(MTL4::CommitFeedback*); +using CommitFeedbackHandlerFunction = std::function; + +class CommitFeedback : public NS::Referencing +{ +public: + CFTimeInterval GPUEndTime() const; + + CFTimeInterval GPUStartTime() const; + + NS::Error* error() const; +}; + +} +_MTL_INLINE CFTimeInterval MTL4::CommitFeedback::GPUEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL4::CommitFeedback::GPUStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUStartTime)); +} + +_MTL_INLINE NS::Error* MTL4::CommitFeedback::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4Compiler.hpp b/dist/include/metal_cpp/Metal/MTL4Compiler.hpp new file mode 100644 index 0000000..94249b2 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4Compiler.hpp @@ -0,0 +1,345 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Compiler.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDevice.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +#include + +namespace MTL4 +{ +class BinaryFunction; +class BinaryFunctionDescriptor; +class CompilerDescriptor; +class CompilerTask; +class CompilerTaskOptions; +class ComputePipelineDescriptor; +class LibraryDescriptor; +class MachineLearningPipelineDescriptor; +class MachineLearningPipelineState; +class PipelineDataSetSerializer; +class PipelineDescriptor; +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; +} + +namespace MTL +{ +class ComputePipelineState; +class Device; +class DynamicLibrary; +class Library; +class RenderPipelineState; + +using NewDynamicLibraryCompletionHandler = void (^)(MTL::DynamicLibrary*, NS::Error*); +using NewDynamicLibraryCompletionHandlerFunction = std::function; +} + +namespace MTL4 +{ +using NewComputePipelineStateCompletionHandler = void (^)(MTL::ComputePipelineState*, NS::Error*); +using NewComputePipelineStateCompletionHandlerFunction = std::function; +using NewRenderPipelineStateCompletionHandler = void (^)(MTL::RenderPipelineState*, NS::Error*); +using 
NewRenderPipelineStateCompletionHandlerFunction = std::function; +using NewBinaryFunctionCompletionHandler = void (^)(MTL4::BinaryFunction*, NS::Error*); +using NewBinaryFunctionCompletionHandlerFunction = std::function; +using NewMachineLearningPipelineStateCompletionHandler = void (^)(MTL4::MachineLearningPipelineState*, NS::Error*); +using NewMachineLearningPipelineStateCompletionHandlerFunction = std::function; + +class CompilerDescriptor : public NS::Copying +{ +public: + static CompilerDescriptor* alloc(); + + CompilerDescriptor* init(); + + NS::String* label() const; + + PipelineDataSetSerializer* pipelineDataSetSerializer() const; + + void setLabel(const NS::String* label); + + void setPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializer* pipelineDataSetSerializer); +}; +class CompilerTaskOptions : public NS::Copying +{ +public: + static CompilerTaskOptions* alloc(); + + CompilerTaskOptions* init(); + + NS::Array* lookupArchives() const; + void setLookupArchives(const NS::Array* lookupArchives); +}; +class Compiler : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::String* label() const; + + BinaryFunction* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL4::NewBinaryFunctionCompletionHandler completionHandler); + + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + MTL::ComputePipelineState* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newComputePipelineState(const 
MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + CompilerTask* newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + CompilerTask* newComputePipelineState(const MTL4::ComputePipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewComputePipelineStateCompletionHandlerFunction& function); + + MTL::DynamicLibrary* newDynamicLibrary(const MTL::Library* library, NS::Error** error); + MTL::DynamicLibrary* newDynamicLibrary(const NS::URL* url, NS::Error** error); + CompilerTask* newDynamicLibrary(const MTL::Library* library, const MTL::NewDynamicLibraryCompletionHandler completionHandler); + CompilerTask* newDynamicLibrary(const NS::URL* url, const MTL::NewDynamicLibraryCompletionHandler completionHandler); + CompilerTask* newDynamicLibrary(const MTL::Library* pLibrary, const MTL::NewDynamicLibraryCompletionHandlerFunction& function); + CompilerTask* newDynamicLibrary(const NS::URL* pURL, const MTL::NewDynamicLibraryCompletionHandlerFunction& function); + + MTL::Library* newLibrary(const MTL4::LibraryDescriptor* descriptor, NS::Error** error); + CompilerTask* newLibrary(const MTL4::LibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler); + CompilerTask* newLibrary(const MTL4::LibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& function); + + MachineLearningPipelineState* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, NS::Error** error); + CompilerTask* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, const 
MTL4::NewMachineLearningPipelineStateCompletionHandler completionHandler); + CompilerTask* newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* pDescriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction& function); + + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + MTL::RenderPipelineState* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineState(const MTL4::PipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function); + MTL::RenderPipelineState* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, NS::Error** error); + CompilerTask* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, const MTL::NewRenderPipelineStateCompletionHandler completionHandler); + CompilerTask* newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* pDescriptor, const MTL::RenderPipelineState* pPipeline, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function); + + 
PipelineDataSetSerializer* pipelineDataSetSerializer() const; +}; + +} +_MTL_INLINE MTL4::CompilerDescriptor* MTL4::CompilerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CompilerDescriptor)); +} + +_MTL_INLINE MTL4::CompilerDescriptor* MTL4::CompilerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::CompilerDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL4::CompilerDescriptor::pipelineDataSetSerializer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pipelineDataSetSerializer)); +} + +_MTL_INLINE void MTL4::CompilerDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::CompilerDescriptor::setPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializer* pipelineDataSetSerializer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPipelineDataSetSerializer_), pipelineDataSetSerializer); +} + +_MTL_INLINE MTL4::CompilerTaskOptions* MTL4::CompilerTaskOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CompilerTaskOptions)); +} + +_MTL_INLINE MTL4::CompilerTaskOptions* MTL4::CompilerTaskOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::CompilerTaskOptions::lookupArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lookupArchives)); +} + +_MTL_INLINE void MTL4::CompilerTaskOptions::setLookupArchives(const NS::Array* lookupArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLookupArchives_), lookupArchives); +} + +_MTL_INLINE MTL::Device* MTL4::Compiler::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL4::Compiler::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::BinaryFunction* MTL4::Compiler::newBinaryFunction(const 
MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newBinaryFunction(const MTL4::BinaryFunctionDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL4::NewBinaryFunctionCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, 
completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* descriptor, const MTL4::PipelineStageDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newComputePipelineState(const MTL4::ComputePipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewComputePipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewComputePipelineStateCompletionHandlerFunction blockFunction = function; + return newComputePipelineState(pDescriptor, options, ^(MTL::ComputePipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL4::Compiler::newDynamicLibrary(const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_error_), library, error); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL4::Compiler::newDynamicLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const MTL::Library* library, const MTL::NewDynamicLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_completionHandler_), library, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const NS::URL* url, const MTL::NewDynamicLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newDynamicLibraryWithURL_completionHandler_), url, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const MTL::Library* pLibrary, const MTL::NewDynamicLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewDynamicLibraryCompletionHandlerFunction blockFunction = function; + return newDynamicLibrary(pLibrary, ^(MTL::DynamicLibrary* pLibraryRef, NS::Error* pError) { blockFunction(pLibraryRef, pError); }); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newDynamicLibrary(const NS::URL* pURL, const MTL::NewDynamicLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewDynamicLibraryCompletionHandlerFunction blockFunction = function; + return newDynamicLibrary(pURL, ^(MTL::DynamicLibrary* pLibrary, NS::Error* pError) { blockFunction(pLibrary, pError); }); +} + +_MTL_INLINE MTL::Library* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newLibrary(const MTL4::LibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& function) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockFunction = function; + return newLibrary(pDescriptor, ^(MTL::Library* pLibrary, NS::Error* pError) { blockFunction(pLibrary, pError); }); +} + +_MTL_INLINE MTL4::MachineLearningPipelineState* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newMachineLearningPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* descriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMachineLearningPipelineStateWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newMachineLearningPipelineState(const MTL4::MachineLearningPipelineDescriptor* pDescriptor, const MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewMachineLearningPipelineStateCompletionHandlerFunction blockFunction = function; + return newMachineLearningPipelineState(pDescriptor, ^(MTL4::MachineLearningPipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_error_), descriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const 
MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_completionHandler_), descriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* descriptor, const MTL4::RenderPipelineDynamicLinkingDescriptor* dynamicLinkingDescriptor, const MTL4::CompilerTaskOptions* compilerTaskOptions, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_), descriptor, dynamicLinkingDescriptor, compilerTaskOptions, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineState(const MTL4::PipelineDescriptor* pDescriptor, const MTL4::CompilerTaskOptions* options, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewRenderPipelineStateCompletionHandlerFunction blockFunction = function; + return newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipeline, NS::Error* pError) { blockFunction(pPipeline, pError); }); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_error_), descriptor, pipeline, error); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* descriptor, const MTL::RenderPipelineState* pipeline, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_completionHandler_), descriptor, pipeline, completionHandler); +} + +_MTL_INLINE MTL4::CompilerTask* MTL4::Compiler::newRenderPipelineStateBySpecialization(const MTL4::PipelineDescriptor* pDescriptor, const MTL::RenderPipelineState* pPipeline, const MTL4::NewRenderPipelineStateCompletionHandlerFunction& function) +{ + __block MTL4::NewRenderPipelineStateCompletionHandlerFunction blockFunction = function; + return newRenderPipelineStateBySpecialization(pDescriptor, pPipeline, ^(MTL::RenderPipelineState* pPipelineRef, NS::Error* pError) { blockFunction(pPipelineRef, pError); }); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL4::Compiler::pipelineDataSetSerializer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pipelineDataSetSerializer)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4CompilerTask.hpp b/dist/include/metal_cpp/Metal/MTL4CompilerTask.hpp new file mode 100644 index 0000000..a1ee9cd --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4CompilerTask.hpp @@ -0,0 +1,63 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4CompilerTask.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class Compiler; +_MTL_ENUM(NS::Integer, CompilerTaskStatus) { + CompilerTaskStatusNone = 0, + CompilerTaskStatusScheduled = 1, + CompilerTaskStatusCompiling = 2, + CompilerTaskStatusFinished = 3, +}; + +class CompilerTask : public NS::Referencing +{ +public: + Compiler* compiler() const; + + CompilerTaskStatus status() const; + + void waitUntilCompleted(); +}; + +} + +_MTL_INLINE MTL4::Compiler* MTL4::CompilerTask::compiler() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compiler)); +} + +_MTL_INLINE MTL4::CompilerTaskStatus MTL4::CompilerTask::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL4::CompilerTask::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4ComputeCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTL4ComputeCommandEncoder.hpp new file mode 100644 index 0000000..7ef19da --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4ComputeCommandEncoder.hpp @@ -0,0 +1,300 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ComputeCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4Counters.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLBlitCommandEncoder.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL4 +{ +class AccelerationStructureDescriptor; +class ArgumentTable; +class CounterHeap; +} + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class ComputePipelineState; +class IndirectCommandBuffer; +class Tensor; +class TensorExtents; +class Texture; +} + +namespace MTL4 +{ +class ComputeCommandEncoder : public NS::Referencing +{ +public: + void buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL4::BufferRange scratchBuffer); + + void copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger 
sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options); + + void copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions); + + void copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage); + void 
copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options); + + void copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex); + + void dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerThreadgroup); + + void dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreads(MTL::GPUAddress indirectBuffer); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, MTL::GPUAddress indirectRangeBuffer); + + void fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value); + + void generateMipmaps(const MTL::Texture* texture); + + void optimizeContentsForCPUAccess(const MTL::Texture* texture); + void optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeContentsForGPUAccess(const MTL::Texture* texture); + void optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range); + + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer); + void refitAccelerationStructure(const 
MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer, MTL::AccelerationStructureRefitOptions options); + + void resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range); + + void setArgumentTable(const MTL4::ArgumentTable* argumentTable); + + void setComputePipelineState(const MTL::ComputePipelineState* state); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + MTL::Stages stages(); + + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL4::BufferRange buffer); + + void writeTimestamp(MTL4::TimestampGranularity granularity, const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE void MTL4::ComputeCommandEncoder::buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL4::BufferRange scratchBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(buildAccelerationStructure_descriptor_scratchBuffer_), accelerationStructure, descriptor, scratchBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_), sourceBuffer, sourceOffset, destinationBuffer, destinationOffset, size); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin, options); +} + +_MTL_INLINE void 
MTL4::ComputeCommandEncoder::copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_), sourceTensor, sourceOrigin, sourceDimensions, destinationTensor, destinationOrigin, destinationDimensions); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_toTexture_), sourceTexture, destinationTexture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_), sourceTexture, sourceSlice, sourceLevel, destinationTexture, destinationSlice, destinationLevel, sliceCount, levelCount); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, 
destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage, options); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_), source, sourceRange, destination, destinationIndex); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size 
threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroupsWithIndirectBuffer_threadsPerThreadgroup_), indirectBuffer, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::dispatchThreads(MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsWithIndirectBuffer_), indirectBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, MTL::GPUAddress indirectRangeBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_), indirectCommandbuffer, indirectRangeBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(fillBuffer_range_value_), buffer, range, value); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::generateMipmaps(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(generateMipmapsForTexture_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForCPUAccess(const 
MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_), texture); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeIndirectCommandBuffer_withRange_), indirectCommandBuffer, range); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL4::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL4::BufferRange scratchBuffer, MTL::AccelerationStructureRefitOptions options) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_options_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, options); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetCommandsInBuffer_withRange_), buffer, range); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_), argumentTable); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setComputePipelineState(const MTL::ComputePipelineState* state) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), state); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE MTL::Stages MTL4::ComputeCommandEncoder::stages() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stages)); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL4::BufferRange buffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_), accelerationStructure, buffer); +} + +_MTL_INLINE void MTL4::ComputeCommandEncoder::writeTimestamp(MTL4::TimestampGranularity granularity, const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampWithGranularity_intoHeap_atIndex_), granularity, 
counterHeap, index); +} diff --git a/dist/include/metal_cpp/Metal/MTL4ComputePipeline.hpp b/dist/include/metal_cpp/Metal/MTL4ComputePipeline.hpp new file mode 100644 index 0000000..a808431 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4ComputePipeline.hpp @@ -0,0 +1,158 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4ComputePipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class ComputePipelineDescriptor; +class FunctionDescriptor; +class StaticLinkingDescriptor; + +class ComputePipelineDescriptor : public NS::Copying +{ +public: + static ComputePipelineDescriptor* alloc(); + + FunctionDescriptor* computeFunctionDescriptor() const; + + ComputePipelineDescriptor* init(); + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + MTL::Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setComputeFunctionDescriptor(const MTL4::FunctionDescriptor* computeFunctionDescriptor); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor); + + void setSupportBinaryLinking(bool supportBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth); + + StaticLinkingDescriptor* staticLinkingDescriptor() const; + + bool supportBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const; +}; + +} +_MTL_INLINE MTL4::ComputePipelineDescriptor* MTL4::ComputePipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4ComputePipelineDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* 
MTL4::ComputePipelineDescriptor::computeFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeFunctionDescriptor)); +} + +_MTL_INLINE MTL4::ComputePipelineDescriptor* MTL4::ComputePipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::ComputePipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL4::ComputePipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setComputeFunctionDescriptor(const MTL4::FunctionDescriptor* computeFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputeFunctionDescriptor_), computeFunctionDescriptor); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStaticLinkingDescriptor_), staticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setSupportBinaryLinking(bool supportBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportBinaryLinking_), supportBinaryLinking); +} + +_MTL_INLINE void 
MTL4::ComputePipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::ComputePipelineDescriptor::setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_), threadGroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::ComputePipelineDescriptor::staticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::ComputePipelineDescriptor::supportBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState MTL4::ComputePipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::ComputePipelineDescriptor::threadGroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4Counters.hpp b/dist/include/metal_cpp/Metal/MTL4Counters.hpp new file mode 100644 index 0000000..b507b76 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4Counters.hpp @@ -0,0 +1,138 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4Counters.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include + +namespace MTL4 +{ +class CounterHeapDescriptor; +_MTL_ENUM(NS::Integer, CounterHeapType) { + CounterHeapTypeInvalid, + CounterHeapTypeTimestamp, +}; + +_MTL_ENUM(NS::Integer, TimestampGranularity) { + TimestampGranularityRelaxed = 0, + TimestampGranularityPrecise = 1, +}; + +struct TimestampHeapEntry +{ + uint64_t timestamp; +} _MTL_PACKED; + +class CounterHeapDescriptor : public NS::Copying +{ +public: + static CounterHeapDescriptor* alloc(); + + NS::UInteger count() const; + + CounterHeapDescriptor* init(); + + void setCount(NS::UInteger count); + + void setType(MTL4::CounterHeapType type); + CounterHeapType type() const; +}; +class CounterHeap : public NS::Referencing +{ +public: + NS::UInteger count() const; + void invalidateCounterRange(NS::Range range); + + NS::String* label() const; + + NS::Data* resolveCounterRange(NS::Range range); + + void setLabel(const NS::String* label); + + CounterHeapType type() const; +}; + +} + +_MTL_INLINE MTL4::CounterHeapDescriptor* MTL4::CounterHeapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4CounterHeapDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL4::CounterHeapDescriptor::count() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(count)); +} + 
+_MTL_INLINE MTL4::CounterHeapDescriptor* MTL4::CounterHeapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::CounterHeapDescriptor::setCount(NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCount_), count); +} + +_MTL_INLINE void MTL4::CounterHeapDescriptor::setType(MTL4::CounterHeapType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE MTL4::CounterHeapType MTL4::CounterHeapDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::UInteger MTL4::CounterHeap::count() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(count)); +} + +_MTL_INLINE void MTL4::CounterHeap::invalidateCounterRange(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(invalidateCounterRange_), range); +} + +_MTL_INLINE NS::String* MTL4::CounterHeap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::Data* MTL4::CounterHeap::resolveCounterRange(NS::Range range) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterRange_), range); +} + +_MTL_INLINE void MTL4::CounterHeap::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL4::CounterHeapType MTL4::CounterHeap::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4FunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4FunctionDescriptor.hpp new file mode 100644 index 0000000..9049677 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4FunctionDescriptor.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal//MTL4FunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; + +class FunctionDescriptor : public NS::Copying +{ +public: + static FunctionDescriptor* alloc(); + + FunctionDescriptor* init(); +}; + +} +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::FunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4FunctionDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::FunctionDescriptor::init() +{ + return NS::Object::init(); +} diff --git a/dist/include/metal_cpp/Metal/MTL4LibraryDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4LibraryDescriptor.hpp new file mode 100644 index 0000000..bc491b6 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4LibraryDescriptor.hpp @@ -0,0 +1,98 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LibraryDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class LibraryDescriptor; +} + +namespace MTL +{ +class CompileOptions; +} + +namespace MTL4 +{ +class LibraryDescriptor : public NS::Copying +{ +public: + static LibraryDescriptor* alloc(); + + LibraryDescriptor* init(); + + NS::String* name() const; + + MTL::CompileOptions* options() const; + + void setName(const NS::String* name); + + void setOptions(const MTL::CompileOptions* options); + + void setSource(const NS::String* source); + NS::String* source() const; +}; + +} +_MTL_INLINE MTL4::LibraryDescriptor* MTL4::LibraryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4LibraryDescriptor)); +} + +_MTL_INLINE MTL4::LibraryDescriptor* MTL4::LibraryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::LibraryDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::CompileOptions* MTL4::LibraryDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setOptions(const MTL::CompileOptions* 
options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} + +_MTL_INLINE void MTL4::LibraryDescriptor::setSource(const NS::String* source) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSource_), source); +} + +_MTL_INLINE NS::String* MTL4::LibraryDescriptor::source() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(source)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4LibraryFunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4LibraryFunctionDescriptor.hpp new file mode 100644 index 0000000..1dec4bf --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4LibraryFunctionDescriptor.hpp @@ -0,0 +1,86 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LibraryFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class LibraryFunctionDescriptor; +} + +namespace MTL +{ +class Library; +} + +namespace MTL4 +{ +class LibraryFunctionDescriptor : public NS::Copying +{ +public: + static LibraryFunctionDescriptor* alloc(); + + LibraryFunctionDescriptor* init(); + + MTL::Library* library() const; + + NS::String* name() const; + + void setLibrary(const MTL::Library* library); + + void setName(const NS::String* name); +}; + +} +_MTL_INLINE MTL4::LibraryFunctionDescriptor* MTL4::LibraryFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4LibraryFunctionDescriptor)); +} + +_MTL_INLINE MTL4::LibraryFunctionDescriptor* MTL4::LibraryFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Library* MTL4::LibraryFunctionDescriptor::library() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(library)); +} + +_MTL_INLINE NS::String* MTL4::LibraryFunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE void MTL4::LibraryFunctionDescriptor::setLibrary(const MTL::Library* library) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLibrary_), library); +} + +_MTL_INLINE void MTL4::LibraryFunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} diff --git a/dist/include/metal_cpp/Metal/MTL4LinkingDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4LinkingDescriptor.hpp new file mode 100644 index 0000000..ef5900b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4LinkingDescriptor.hpp @@ -0,0 +1,204 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4LinkingDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineStageDynamicLinkingDescriptor; +class RenderPipelineDynamicLinkingDescriptor; +class StaticLinkingDescriptor; + +class StaticLinkingDescriptor : public NS::Copying +{ +public: + static StaticLinkingDescriptor* alloc(); + + NS::Array* functionDescriptors() const; + + NS::Dictionary* groups() const; + + StaticLinkingDescriptor* init(); + + NS::Array* privateFunctionDescriptors() const; + + void setFunctionDescriptors(const NS::Array* functionDescriptors); + + void setGroups(const NS::Dictionary* groups); + + void setPrivateFunctionDescriptors(const NS::Array* privateFunctionDescriptors); +}; +class PipelineStageDynamicLinkingDescriptor : public NS::Copying +{ +public: + static PipelineStageDynamicLinkingDescriptor* alloc(); + + NS::Array* binaryLinkedFunctions() const; + + PipelineStageDynamicLinkingDescriptor* init(); + + NS::UInteger 
maxCallStackDepth() const; + + NS::Array* preloadedLibraries() const; + + void setBinaryLinkedFunctions(const NS::Array* binaryLinkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); +}; +class RenderPipelineDynamicLinkingDescriptor : public NS::Copying +{ +public: + static RenderPipelineDynamicLinkingDescriptor* alloc(); + + PipelineStageDynamicLinkingDescriptor* fragmentLinkingDescriptor() const; + + RenderPipelineDynamicLinkingDescriptor* init(); + + PipelineStageDynamicLinkingDescriptor* meshLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* objectLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* tileLinkingDescriptor() const; + + PipelineStageDynamicLinkingDescriptor* vertexLinkingDescriptor() const; +}; + +} +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::StaticLinkingDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4StaticLinkingDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::StaticLinkingDescriptor::functionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptors)); +} + +_MTL_INLINE NS::Dictionary* MTL4::StaticLinkingDescriptor::groups() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(groups)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::StaticLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::StaticLinkingDescriptor::privateFunctionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(privateFunctionDescriptors)); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setFunctionDescriptors(const NS::Array* functionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptors_), functionDescriptors); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setGroups(const NS::Dictionary* groups) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGroups_), 
groups); +} + +_MTL_INLINE void MTL4::StaticLinkingDescriptor::setPrivateFunctionDescriptors(const NS::Array* privateFunctionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrivateFunctionDescriptors_), privateFunctionDescriptors); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::PipelineStageDynamicLinkingDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineStageDynamicLinkingDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::PipelineStageDynamicLinkingDescriptor::binaryLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryLinkedFunctions)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::PipelineStageDynamicLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::PipelineStageDynamicLinkingDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::Array* MTL4::PipelineStageDynamicLinkingDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setBinaryLinkedFunctions(const NS::Array* binaryLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryLinkedFunctions_), binaryLinkedFunctions); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL4::PipelineStageDynamicLinkingDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE MTL4::RenderPipelineDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::alloc() +{ + return 
NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineDynamicLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::fragmentLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentLinkingDescriptor)); +} + +_MTL_INLINE MTL4::RenderPipelineDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::meshLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::objectLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::tileLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileLinkingDescriptor)); +} + +_MTL_INLINE MTL4::PipelineStageDynamicLinkingDescriptor* MTL4::RenderPipelineDynamicLinkingDescriptor::vertexLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexLinkingDescriptor)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4MachineLearningCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTL4MachineLearningCommandEncoder.hpp new file mode 100644 index 0000000..4d3cff6 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4MachineLearningCommandEncoder.hpp @@ -0,0 +1,66 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MachineLearningCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class ArgumentTable; +class MachineLearningPipelineState; +} + +namespace MTL +{ +class Heap; +} + +namespace MTL4 +{ +class MachineLearningCommandEncoder : public NS::Referencing +{ +public: + void dispatchNetwork(const MTL::Heap* heap); + + void setArgumentTable(const MTL4::ArgumentTable* argumentTable); + + void setPipelineState(const MTL4::MachineLearningPipelineState* pipelineState); +}; + +} +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::dispatchNetwork(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchNetworkWithIntermediatesHeap_), heap); +} + +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_), argumentTable); +} + +_MTL_INLINE void MTL4::MachineLearningCommandEncoder::setPipelineState(const MTL4::MachineLearningPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPipelineState_), pipelineState); +} diff --git 
a/dist/include/metal_cpp/Metal/MTL4MachineLearningPipeline.hpp b/dist/include/metal_cpp/Metal/MTL4MachineLearningPipeline.hpp new file mode 100644 index 0000000..713569f --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4MachineLearningPipeline.hpp @@ -0,0 +1,172 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MachineLearningPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class MachineLearningPipelineDescriptor; +class MachineLearningPipelineReflection; +} + +namespace MTL +{ +class Device; +class TensorExtents; +} + +namespace MTL4 +{ +class MachineLearningPipelineDescriptor : public NS::Copying +{ +public: + static MachineLearningPipelineDescriptor* alloc(); + + MachineLearningPipelineDescriptor* init(); + + MTL::TensorExtents* inputDimensionsAtBufferIndex(NS::Integer bufferIndex); + + NS::String* label() const; + + FunctionDescriptor* machineLearningFunctionDescriptor() const; + + void reset(); + + void setInputDimensions(const MTL::TensorExtents* dimensions, NS::Integer bufferIndex); + void setInputDimensions(const NS::Array* dimensions, NS::Range range); + + void setLabel(const NS::String* label); + + void setMachineLearningFunctionDescriptor(const MTL4::FunctionDescriptor* machineLearningFunctionDescriptor); +}; +class MachineLearningPipelineReflection : public NS::Referencing +{ +public: + static MachineLearningPipelineReflection* alloc(); + + NS::Array* bindings() const; + + MachineLearningPipelineReflection* init(); +}; +class MachineLearningPipelineState : public NS::Referencing +{ +public: + MTL::Device* device() const; + + NS::UInteger intermediatesHeapSize() const; + + NS::String* label() const; + + MachineLearningPipelineReflection* reflection() const; +}; + +} +_MTL_INLINE MTL4::MachineLearningPipelineDescriptor* MTL4::MachineLearningPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MachineLearningPipelineDescriptor)); +} + +_MTL_INLINE 
MTL4::MachineLearningPipelineDescriptor* MTL4::MachineLearningPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorExtents* MTL4::MachineLearningPipelineDescriptor::inputDimensionsAtBufferIndex(NS::Integer bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputDimensionsAtBufferIndex_), bufferIndex); +} + +_MTL_INLINE NS::String* MTL4::MachineLearningPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MachineLearningPipelineDescriptor::machineLearningFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(machineLearningFunctionDescriptor)); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setInputDimensions(const MTL::TensorExtents* dimensions, NS::Integer bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputDimensions_atBufferIndex_), dimensions, bufferIndex); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setInputDimensions(const NS::Array* dimensions, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputDimensions_withRange_), dimensions, range); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::MachineLearningPipelineDescriptor::setMachineLearningFunctionDescriptor(const MTL4::FunctionDescriptor* machineLearningFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMachineLearningFunctionDescriptor_), machineLearningFunctionDescriptor); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MachineLearningPipelineReflection)); +} + 
+_MTL_INLINE NS::Array* MTL4::MachineLearningPipelineReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Device* MTL4::MachineLearningPipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::UInteger MTL4::MachineLearningPipelineState::intermediatesHeapSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intermediatesHeapSize)); +} + +_MTL_INLINE NS::String* MTL4::MachineLearningPipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::MachineLearningPipelineReflection* MTL4::MachineLearningPipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4MeshRenderPipeline.hpp b/dist/include/metal_cpp/Metal/MTL4MeshRenderPipeline.hpp new file mode 100644 index 0000000..f66dffe --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4MeshRenderPipeline.hpp @@ -0,0 +1,413 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4MeshRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTL4RenderPipeline.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class MeshRenderPipelineDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class StaticLinkingDescriptor; + +class MeshRenderPipelineDescriptor : public NS::Copying +{ +public: + static MeshRenderPipelineDescriptor* alloc(); + + AlphaToCoverageState alphaToCoverageState() const; + + AlphaToOneState alphaToOneState() const; + + LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + FunctionDescriptor* fragmentFunctionDescriptor() const; + + StaticLinkingDescriptor* fragmentStaticLinkingDescriptor() const; + + MeshRenderPipelineDescriptor* init(); + + bool isRasterizationEnabled() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxVertexAmplificationCount() const; + + FunctionDescriptor* meshFunctionDescriptor() const; + + StaticLinkingDescriptor* meshStaticLinkingDescriptor() const; + + bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + FunctionDescriptor* objectFunctionDescriptor() const; + + StaticLinkingDescriptor* objectStaticLinkingDescriptor() const; + + bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + NS::UInteger payloadMemoryLength() const; + + NS::UInteger rasterSampleCount() const; + + 
[[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + MTL::Size requiredThreadsPerMeshThreadgroup() const; + + MTL::Size requiredThreadsPerObjectThreadgroup() const; + + void reset(); + + void setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState); + + void setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState); + + void setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState); + + void setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor); + + void setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor); + + void setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid); + + void setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup); + + void setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setMeshFunctionDescriptor(const MTL4::FunctionDescriptor* meshFunctionDescriptor); + + void setMeshStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* meshStaticLinkingDescriptor); + + void setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setObjectFunctionDescriptor(const MTL4::FunctionDescriptor* objectFunctionDescriptor); + + void setObjectStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* objectStaticLinkingDescriptor); + + void setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setPayloadMemoryLength(NS::UInteger payloadMemoryLength); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setRequiredThreadsPerMeshThreadgroup(MTL::Size 
requiredThreadsPerMeshThreadgroup); + + void setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup); + + void setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setSupportMeshBinaryLinking(bool supportMeshBinaryLinking); + + void setSupportObjectBinaryLinking(bool supportObjectBinaryLinking); + + bool supportFragmentBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool supportMeshBinaryLinking() const; + + bool supportObjectBinaryLinking() const; +}; + +} +_MTL_INLINE MTL4::MeshRenderPipelineDescriptor* MTL4::MeshRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4MeshRenderPipelineDescriptor)); +} + +_MTL_INLINE MTL4::AlphaToCoverageState MTL4::MeshRenderPipelineDescriptor::alphaToCoverageState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToCoverageState)); +} + +_MTL_INLINE MTL4::AlphaToOneState MTL4::MeshRenderPipelineDescriptor::alphaToOneState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToOneState)); +} + +_MTL_INLINE MTL4::LogicalToPhysicalColorAttachmentMappingState MTL4::MeshRenderPipelineDescriptor::colorAttachmentMappingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachmentMappingState)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::MeshRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::fragmentFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::fragmentStaticLinkingDescriptor() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentStaticLinkingDescriptor)); +} + +_MTL_INLINE MTL4::MeshRenderPipelineDescriptor* MTL4::MeshRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::meshFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::meshStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshStaticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::MeshRenderPipelineDescriptor::objectFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectFunctionDescriptor)); +} + 
+_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::MeshRenderPipelineDescriptor::objectStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectStaticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::payloadMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(payloadMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL4::MeshRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE MTL::Size MTL4::MeshRenderPipelineDescriptor::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL4::MeshRenderPipelineDescriptor::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageState_), alphaToCoverageState); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneState_), alphaToOneState); +} + +_MTL_INLINE void 
MTL4::MeshRenderPipelineDescriptor::setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMappingState_), colorAttachmentMappingState); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunctionDescriptor_), fragmentFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentStaticLinkingDescriptor_), fragmentStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadgroupsPerMeshGrid_), maxTotalThreadgroupsPerMeshGrid); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerMeshThreadgroup_), maxTotalThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerObjectThreadgroup_), maxTotalThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshFunctionDescriptor(const MTL4::FunctionDescriptor* meshFunctionDescriptor) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshFunctionDescriptor_), meshFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* meshStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshStaticLinkingDescriptor_), meshStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_), meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectFunctionDescriptor(const MTL4::FunctionDescriptor* objectFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectFunctionDescriptor_), objectFunctionDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* objectStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectStaticLinkingDescriptor_), objectStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_), objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setPayloadMemoryLength(NS::UInteger payloadMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPayloadMemoryLength_), payloadMemoryLength); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE 
void MTL4::MeshRenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerMeshThreadgroup_), requiredThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerObjectThreadgroup_), requiredThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportFragmentBinaryLinking_), supportFragmentBinaryLinking); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportMeshBinaryLinking(bool supportMeshBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportMeshBinaryLinking_), supportMeshBinaryLinking); +} + +_MTL_INLINE void MTL4::MeshRenderPipelineDescriptor::setSupportObjectBinaryLinking(bool supportObjectBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportObjectBinaryLinking_), supportObjectBinaryLinking); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportFragmentBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportFragmentBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState 
MTL4::MeshRenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportMeshBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportMeshBinaryLinking)); +} + +_MTL_INLINE bool MTL4::MeshRenderPipelineDescriptor::supportObjectBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportObjectBinaryLinking)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4PipelineDataSetSerializer.hpp b/dist/include/metal_cpp/Metal/MTL4PipelineDataSetSerializer.hpp new file mode 100644 index 0000000..9dbd610 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4PipelineDataSetSerializer.hpp @@ -0,0 +1,85 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4PipelineDataSetSerializer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineDataSetSerializerDescriptor; + +_MTL_OPTIONS(NS::UInteger, PipelineDataSetSerializerConfiguration) { + PipelineDataSetSerializerConfigurationCaptureDescriptors = 1, + PipelineDataSetSerializerConfigurationCaptureBinaries = 1 << 1, +}; + +class PipelineDataSetSerializerDescriptor : public NS::Copying +{ +public: + static PipelineDataSetSerializerDescriptor* alloc(); + + PipelineDataSetSerializerConfiguration configuration() const; + + PipelineDataSetSerializerDescriptor* init(); + + void setConfiguration(MTL4::PipelineDataSetSerializerConfiguration configuration); +}; +class PipelineDataSetSerializer : public NS::Referencing +{ +public: + bool serializeAsArchiveAndFlushToURL(const NS::URL* url, NS::Error** error); + + NS::Data* serializeAsPipelinesScript(NS::Error** error); +}; + +} +_MTL_INLINE MTL4::PipelineDataSetSerializerDescriptor* MTL4::PipelineDataSetSerializerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineDataSetSerializerDescriptor)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializerConfiguration MTL4::PipelineDataSetSerializerDescriptor::configuration() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(configuration)); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializerDescriptor* MTL4::PipelineDataSetSerializerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::PipelineDataSetSerializerDescriptor::setConfiguration(MTL4::PipelineDataSetSerializerConfiguration configuration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConfiguration_), configuration); +} + +_MTL_INLINE bool MTL4::PipelineDataSetSerializer::serializeAsArchiveAndFlushToURL(const 
NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeAsArchiveAndFlushToURL_error_), url, error); +} + +_MTL_INLINE NS::Data* MTL4::PipelineDataSetSerializer::serializeAsPipelinesScript(NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeAsPipelinesScriptWithError_), error); +} diff --git a/dist/include/metal_cpp/Metal/MTL4PipelineState.hpp b/dist/include/metal_cpp/Metal/MTL4PipelineState.hpp new file mode 100644 index 0000000..cecefa8 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4PipelineState.hpp @@ -0,0 +1,150 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4PipelineState.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class PipelineDescriptor; +class PipelineOptions; +_MTL_ENUM(NS::Integer, AlphaToOneState) { + AlphaToOneStateDisabled = 0, + AlphaToOneStateEnabled = 1, +}; + +_MTL_ENUM(NS::Integer, AlphaToCoverageState) { + AlphaToCoverageStateDisabled = 0, + AlphaToCoverageStateEnabled = 1, +}; + +_MTL_ENUM(NS::Integer, BlendState) { + BlendStateDisabled = 0, + BlendStateEnabled = 1, + BlendStateUnspecialized = 2, +}; + +_MTL_ENUM(NS::Integer, IndirectCommandBufferSupportState) { + IndirectCommandBufferSupportStateDisabled = 0, + IndirectCommandBufferSupportStateEnabled = 1, +}; + +_MTL_OPTIONS(NS::UInteger, ShaderReflection) { + ShaderReflectionNone = 0, + ShaderReflectionBindingInfo = 1, + ShaderReflectionBufferTypeInfo = 1 << 1, +}; + +class PipelineOptions : public NS::Copying +{ +public: + static PipelineOptions* alloc(); + + PipelineOptions* init(); + + void setShaderReflection(MTL4::ShaderReflection shaderReflection); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + ShaderReflection shaderReflection() const; + + MTL::ShaderValidation shaderValidation() const; +}; +class PipelineDescriptor : public NS::Copying +{ +public: + static PipelineDescriptor* alloc(); + + PipelineDescriptor* init(); + + NS::String* label() const; + + PipelineOptions* options() const; + + void setLabel(const NS::String* label); + + void setOptions(const MTL4::PipelineOptions* options); +}; + +} +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineOptions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineOptions)); +} + +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineOptions::init() +{ + return 
NS::Object::init(); +} + +_MTL_INLINE void MTL4::PipelineOptions::setShaderReflection(MTL4::ShaderReflection shaderReflection) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderReflection_), shaderReflection); +} + +_MTL_INLINE void MTL4::PipelineOptions::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE MTL4::ShaderReflection MTL4::PipelineOptions::shaderReflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderReflection)); +} + +_MTL_INLINE MTL::ShaderValidation MTL4::PipelineOptions::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL4::PipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4PipelineDescriptor)); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL4::PipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL4::PipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL4::PipelineOptions* MTL4::PipelineDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL4::PipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL4::PipelineDescriptor::setOptions(const MTL4::PipelineOptions* options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/dist/include/metal_cpp/Metal/MTL4RenderCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTL4RenderCommandEncoder.hpp new file mode 100644 index 0000000..0dd01f4 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4RenderCommandEncoder.hpp @@ -0,0 +1,340 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4Counters.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLRenderPass.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL4 +{ +class ArgumentTable; +class CounterHeap; +} + +namespace MTL +{ +class DepthStencilState; +class IndirectCommandBuffer; +class LogicalToPhysicalColorAttachmentMap; +class RenderPipelineState; +struct ScissorRect; +struct VertexAmplificationViewMapping; +struct Viewport; + +} +namespace MTL4 +{ +_MTL_OPTIONS(NS::UInteger, RenderEncoderOptions) { + RenderEncoderOptionNone = 0, + RenderEncoderOptionSuspending = 1, + RenderEncoderOptionResuming = 1 << 1, +}; + +class RenderCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadsPerTile(MTL::Size threadsPerTile); + + void 
drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, MTL::GPUAddress indirectBuffer); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + void drawMeshThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPrimitives(MTL::PrimitiveType primitiveType, MTL::GPUAddress indirectBuffer); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, MTL::GPUAddress indirectRangeBuffer); + + void setArgumentTable(const 
MTL4::ArgumentTable* argumentTable, MTL::RenderStages stages); + + void setBlendColor(float red, float green, float blue, float alpha); + + void setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setDepthStoreAction(MTL::StoreAction storeAction); + + void setDepthTestBounds(float minBound, float maxBound); + + void setFrontFacingWinding(MTL::Winding frontFacingWinding); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setScissorRect(MTL::ScissorRect rect); + void setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count); + + void setStencilReferenceValue(uint32_t referenceValue); + void setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue); + + void setStencilStoreAction(MTL::StoreAction storeAction); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings); + + void setViewport(MTL::Viewport viewport); + void setViewports(const MTL::Viewport* viewports, NS::UInteger count); + + void setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset); + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + void writeTimestamp(MTL4::TimestampGranularity granularity, MTL::RenderStages stage, const MTL4::CounterHeap* counterHeap, NS::UInteger index); +}; + +} +_MTL_INLINE 
void MTL4::RenderCommandEncoder::dispatchThreadsPerTile(MTL::Size threadsPerTile) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsPerTile_), threadsPerTile); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength, instanceCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, indexBufferLength, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, MTL::GPUAddress indexBuffer, NS::UInteger indexBufferLength, MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferLength_indirectBuffer_), primitiveType, 
indexType, indexBuffer, indexBufferLength, indirectBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreadgroups(MTL::GPUAddress indirectBuffer, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroupsWithIndirectBuffer_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), indirectBuffer, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_), primitiveType, vertexStart, vertexCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_), primitiveType, vertexStart, vertexCount, instanceCount); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger 
instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, MTL::GPUAddress indirectBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_indirectBuffer_), primitiveType, indirectBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, MTL::GPUAddress indirectRangeBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_), indirectCommandBuffer, indirectRangeBuffer); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setArgumentTable(const MTL4::ArgumentTable* argumentTable, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentTable_atStages_), argumentTable, stages); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setBlendColor(float red, float green, float blue, float alpha) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendColorRed_green_blue_alpha_), red, green, blue, alpha); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMap_), mapping); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + 
+_MTL_INLINE void MTL4::RenderCommandEncoder::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthBias(float depthBias, float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setDepthTestBounds(float minBound, float maxBound) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthTestMinBound_maxBound_), minBound, maxBound); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setFrontFacingWinding(MTL::Winding frontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWinding); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setScissorRect(MTL::ScissorRect rect) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRect_), rect); +} + +_MTL_INLINE void 
MTL4::RenderCommandEncoder::setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRects_count_), scissorRects, count); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilReferenceValue(uint32_t referenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilReferenceValue_), referenceValue); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFrontReferenceValue_backReferenceValue_), frontReferenceValue, backReferenceValue); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_offset_atIndex_), length, offset, index); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAmplificationCount_viewMappings_), count, viewMappings); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setViewport(MTL::Viewport viewport) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewport_), viewport); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::setViewports(const MTL::Viewport* viewports, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewports_count_), viewports, count); +} + +_MTL_INLINE void 
MTL4::RenderCommandEncoder::setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultMode_offset_), mode, offset); +} + +_MTL_INLINE NS::UInteger MTL4::RenderCommandEncoder::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderCommandEncoder::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE void MTL4::RenderCommandEncoder::writeTimestamp(MTL4::TimestampGranularity granularity, MTL::RenderStages stage, const MTL4::CounterHeap* counterHeap, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeTimestampWithGranularity_afterStage_intoHeap_atIndex_), granularity, stage, counterHeap, index); +} diff --git a/dist/include/metal_cpp/Metal/MTL4RenderPass.hpp b/dist/include/metal_cpp/Metal/MTL4RenderPass.hpp new file mode 100644 index 0000000..c5aa9ed --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4RenderPass.hpp @@ -0,0 +1,280 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" + +namespace MTL4 +{ +class RenderPassDescriptor; +} + +namespace MTL +{ +class Buffer; +class RasterizationRateMap; +class RenderPassColorAttachmentDescriptorArray; +class RenderPassDepthAttachmentDescriptor; +class RenderPassStencilAttachmentDescriptor; +struct SamplePosition; +} + +namespace MTL4 +{ +class RenderPassDescriptor : public NS::Copying +{ +public: + static RenderPassDescriptor* alloc(); + + MTL::RenderPassColorAttachmentDescriptorArray* colorAttachments() const; + + NS::UInteger defaultRasterSampleCount() const; + + MTL::RenderPassDepthAttachmentDescriptor* depthAttachment() const; + + NS::UInteger getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + NS::UInteger imageblockSampleLength() const; + + RenderPassDescriptor* init(); + + MTL::RasterizationRateMap* rasterizationRateMap() const; + + NS::UInteger renderTargetArrayLength() const; + + NS::UInteger renderTargetHeight() const; + + NS::UInteger renderTargetWidth() const; + + void setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount); + + void setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment); + + void setImageblockSampleLength(NS::UInteger imageblockSampleLength); + + void setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap); + + void setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength); + + void setRenderTargetHeight(NS::UInteger renderTargetHeight); + + void setRenderTargetWidth(NS::UInteger renderTargetWidth); + + void setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count); + + void setStencilAttachment(const 
MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength); + + void setTileHeight(NS::UInteger tileHeight); + + void setTileWidth(NS::UInteger tileWidth); + + void setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer); + + void setVisibilityResultType(MTL::VisibilityResultType visibilityResultType); + + MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment() const; + + bool supportColorAttachmentMapping() const; + + NS::UInteger threadgroupMemoryLength() const; + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + MTL::Buffer* visibilityResultBuffer() const; + + MTL::VisibilityResultType visibilityResultType() const; +}; + +} +_MTL_INLINE MTL4::RenderPassDescriptor* MTL4::RenderPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPassDescriptor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL4::RenderPassDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::defaultRasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultRasterSampleCount)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL4::RenderPassDescriptor::depthAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachment)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getSamplePositions_count_), positions, count); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE MTL4::RenderPassDescriptor* 
MTL4::RenderPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL4::RenderPassDescriptor::rasterizationRateMap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterizationRateMap)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetArrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetArrayLength)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::renderTargetWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetWidth)); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultRasterSampleCount_), defaultRasterSampleCount); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachment_), depthAttachment); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setImageblockSampleLength(NS::UInteger imageblockSampleLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockSampleLength_), imageblockSampleLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationRateMap_), rasterizationRateMap); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetArrayLength_), renderTargetArrayLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetHeight(NS::UInteger renderTargetHeight) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setRenderTargetHeight_), renderTargetHeight); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setRenderTargetWidth(NS::UInteger renderTargetWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetWidth_), renderTargetWidth); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplePositions_count_), positions, count); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachment_), stencilAttachment); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_), threadgroupMemoryLength); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setTileHeight(NS::UInteger tileHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileHeight_), tileHeight); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setTileWidth(NS::UInteger tileWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileWidth_), tileWidth); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultBuffer_), visibilityResultBuffer); +} + +_MTL_INLINE void MTL4::RenderPassDescriptor::setVisibilityResultType(MTL::VisibilityResultType visibilityResultType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultType_), visibilityResultType); +} + +_MTL_INLINE 
MTL::RenderPassStencilAttachmentDescriptor* MTL4::RenderPassDescriptor::stencilAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachment)); +} + +_MTL_INLINE bool MTL4::RenderPassDescriptor::supportColorAttachmentMapping() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::threadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPassDescriptor::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE MTL::Buffer* MTL4::RenderPassDescriptor::visibilityResultBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultBuffer)); +} + +_MTL_INLINE MTL::VisibilityResultType MTL4::RenderPassDescriptor::visibilityResultType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultType)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4RenderPipeline.hpp b/dist/include/metal_cpp/Metal/MTL4RenderPipeline.hpp new file mode 100644 index 0000000..fc2e5e6 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4RenderPipeline.hpp @@ -0,0 +1,587 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4RenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPipeline.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class RenderPipelineBinaryFunctionsDescriptor; +class RenderPipelineColorAttachmentDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class RenderPipelineDescriptor; +class StaticLinkingDescriptor; +} + +namespace MTL +{ +class VertexDescriptor; +} + +namespace MTL4 +{ +_MTL_ENUM(NS::Integer, LogicalToPhysicalColorAttachmentMappingState) { + LogicalToPhysicalColorAttachmentMappingStateIdentity = 0, + LogicalToPhysicalColorAttachmentMappingStateInherited = 1, +}; + +class RenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptor* alloc(); + + MTL::BlendOperation alphaBlendOperation() const; + + BlendState blendingState() const; + + MTL::BlendFactor destinationAlphaBlendFactor() const; + + MTL::BlendFactor destinationRGBBlendFactor() const; + + RenderPipelineColorAttachmentDescriptor* init(); + + MTL::PixelFormat pixelFormat() const; + + void reset(); + + MTL::BlendOperation rgbBlendOperation() const; + + void setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation); + + void setBlendingState(MTL4::BlendState 
blendingState); + + void setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor); + + void setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation); + + void setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor); + + void setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor); + + void setWriteMask(MTL::ColorWriteMask writeMask); + + MTL::BlendFactor sourceAlphaBlendFactor() const; + + MTL::BlendFactor sourceRGBBlendFactor() const; + + MTL::ColorWriteMask writeMask() const; +}; + +class RenderPipelineColorAttachmentDescriptorArray : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptorArray* alloc(); + + RenderPipelineColorAttachmentDescriptorArray* init(); + + RenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + + void reset(); + + void setObject(const MTL4::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; + +class RenderPipelineBinaryFunctionsDescriptor : public NS::Copying +{ +public: + static RenderPipelineBinaryFunctionsDescriptor* alloc(); + + NS::Array* fragmentAdditionalBinaryFunctions() const; + + RenderPipelineBinaryFunctionsDescriptor* init(); + + NS::Array* meshAdditionalBinaryFunctions() const; + + NS::Array* objectAdditionalBinaryFunctions() const; + + void reset(); + + void setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions); + + void setMeshAdditionalBinaryFunctions(const NS::Array* meshAdditionalBinaryFunctions); + + void setObjectAdditionalBinaryFunctions(const NS::Array* objectAdditionalBinaryFunctions); + + void setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions); + + void setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions); + + NS::Array* tileAdditionalBinaryFunctions() const; 
+ + NS::Array* vertexAdditionalBinaryFunctions() const; +}; + +class RenderPipelineDescriptor : public NS::Copying +{ +public: + static RenderPipelineDescriptor* alloc(); + + AlphaToCoverageState alphaToCoverageState() const; + + AlphaToOneState alphaToOneState() const; + + LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + FunctionDescriptor* fragmentFunctionDescriptor() const; + + StaticLinkingDescriptor* fragmentStaticLinkingDescriptor() const; + + RenderPipelineDescriptor* init(); + + MTL::PrimitiveTopologyClass inputPrimitiveTopology() const; + + bool isRasterizationEnabled() const; + + NS::UInteger maxVertexAmplificationCount() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + void reset(); + + void setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState); + + void setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState); + + void setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState); + + void setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor); + + void setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor); + + void setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking); + + void setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers); + + void setSupportVertexBinaryLinking(bool supportVertexBinaryLinking); + + void 
setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor); + + void setVertexFunctionDescriptor(const MTL4::FunctionDescriptor* vertexFunctionDescriptor); + + void setVertexStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* vertexStaticLinkingDescriptor); + + bool supportFragmentBinaryLinking() const; + + IndirectCommandBufferSupportState supportIndirectCommandBuffers() const; + + bool supportVertexBinaryLinking() const; + + MTL::VertexDescriptor* vertexDescriptor() const; + + FunctionDescriptor* vertexFunctionDescriptor() const; + + StaticLinkingDescriptor* vertexStaticLinkingDescriptor() const; +}; + +} +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::BlendOperation MTL4::RenderPipelineColorAttachmentDescriptor::alphaBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaBlendOperation)); +} + +_MTL_INLINE MTL4::BlendState MTL4::RenderPipelineColorAttachmentDescriptor::blendingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blendingState)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::destinationAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::destinationRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationRGBBlendFactor)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PixelFormat MTL4::RenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void 
MTL4::RenderPipelineColorAttachmentDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE MTL::BlendOperation MTL4::RenderPipelineColorAttachmentDescriptor::rgbBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rgbBlendOperation)); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaBlendOperation_), alphaBlendOperation); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setBlendingState(MTL4::BlendState blendingState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendingState_), blendingState); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationAlphaBlendFactor_), destinationAlphaBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationRGBBlendFactor_), destinationRGBBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRgbBlendOperation_), rgbBlendOperation); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceAlphaBlendFactor_), sourceAlphaBlendFactor); +} + +_MTL_INLINE void 
MTL4::RenderPipelineColorAttachmentDescriptor::setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceRGBBlendFactor_), sourceRGBBlendFactor); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptor::setWriteMask(MTL::ColorWriteMask writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::sourceAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL4::RenderPipelineColorAttachmentDescriptor::sourceRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceRGBBlendFactor)); +} + +_MTL_INLINE MTL::ColorWriteMask MTL4::RenderPipelineColorAttachmentDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptor* MTL4::RenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptorArray::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineColorAttachmentDescriptorArray::setObject(const MTL4::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), 
attachment, attachmentIndex); +} + +_MTL_INLINE MTL4::RenderPipelineBinaryFunctionsDescriptor* MTL4::RenderPipelineBinaryFunctionsDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineBinaryFunctionsDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::fragmentAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL4::RenderPipelineBinaryFunctionsDescriptor* MTL4::RenderPipelineBinaryFunctionsDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::meshAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::objectAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAdditionalBinaryFunctions)); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAdditionalBinaryFunctions_), fragmentAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setMeshAdditionalBinaryFunctions(const NS::Array* meshAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshAdditionalBinaryFunctions_), meshAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setObjectAdditionalBinaryFunctions(const NS::Array* objectAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectAdditionalBinaryFunctions_), objectAdditionalBinaryFunctions); +} + +_MTL_INLINE void 
MTL4::RenderPipelineBinaryFunctionsDescriptor::setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAdditionalBinaryFunctions_), tileAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL4::RenderPipelineBinaryFunctionsDescriptor::setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAdditionalBinaryFunctions_), vertexAdditionalBinaryFunctions); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::tileAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL4::RenderPipelineBinaryFunctionsDescriptor::vertexAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL4::RenderPipelineDescriptor* MTL4::RenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4RenderPipelineDescriptor)); +} + +_MTL_INLINE MTL4::AlphaToCoverageState MTL4::RenderPipelineDescriptor::alphaToCoverageState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToCoverageState)); +} + +_MTL_INLINE MTL4::AlphaToOneState MTL4::RenderPipelineDescriptor::alphaToOneState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaToOneState)); +} + +_MTL_INLINE MTL4::LogicalToPhysicalColorAttachmentMappingState MTL4::RenderPipelineDescriptor::colorAttachmentMappingState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachmentMappingState)); +} + +_MTL_INLINE MTL4::RenderPipelineColorAttachmentDescriptorArray* MTL4::RenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::RenderPipelineDescriptor::fragmentFunctionDescriptor() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::RenderPipelineDescriptor::fragmentStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentStaticLinkingDescriptor)); +} + +_MTL_INLINE MTL4::RenderPipelineDescriptor* MTL4::RenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PrimitiveTopologyClass MTL4::RenderPipelineDescriptor::inputPrimitiveTopology() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputPrimitiveTopology)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE NS::UInteger MTL4::RenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setAlphaToCoverageState(MTL4::AlphaToCoverageState alphaToCoverageState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageState_), alphaToCoverageState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setAlphaToOneState(MTL4::AlphaToOneState alphaToOneState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneState_), alphaToOneState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setColorAttachmentMappingState(MTL4::LogicalToPhysicalColorAttachmentMappingState colorAttachmentMappingState) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMappingState_), colorAttachmentMappingState); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setFragmentFunctionDescriptor(const MTL4::FunctionDescriptor* fragmentFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunctionDescriptor_), fragmentFunctionDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setFragmentStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* fragmentStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentStaticLinkingDescriptor_), fragmentStaticLinkingDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputPrimitiveTopology_), inputPrimitiveTopology); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportFragmentBinaryLinking(bool supportFragmentBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportFragmentBinaryLinking_), supportFragmentBinaryLinking); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportIndirectCommandBuffers(MTL4::IndirectCommandBufferSupportState supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), 
supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setSupportVertexBinaryLinking(bool supportVertexBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportVertexBinaryLinking_), supportVertexBinaryLinking); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexDescriptor_), vertexDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexFunctionDescriptor(const MTL4::FunctionDescriptor* vertexFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFunctionDescriptor_), vertexFunctionDescriptor); +} + +_MTL_INLINE void MTL4::RenderPipelineDescriptor::setVertexStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* vertexStaticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStaticLinkingDescriptor_), vertexStaticLinkingDescriptor); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::supportFragmentBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportFragmentBinaryLinking)); +} + +_MTL_INLINE MTL4::IndirectCommandBufferSupportState MTL4::RenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL4::RenderPipelineDescriptor::supportVertexBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportVertexBinaryLinking)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL4::RenderPipelineDescriptor::vertexDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexDescriptor)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::RenderPipelineDescriptor::vertexFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFunctionDescriptor)); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* 
MTL4::RenderPipelineDescriptor::vertexStaticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStaticLinkingDescriptor)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp new file mode 100644 index 0000000..57c0094 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4SpecializedFunctionDescriptor.hpp @@ -0,0 +1,100 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4SpecializedFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class SpecializedFunctionDescriptor; +} + +namespace MTL +{ +class FunctionConstantValues; +} + +namespace MTL4 +{ +class SpecializedFunctionDescriptor : public NS::Copying +{ +public: + static SpecializedFunctionDescriptor* alloc(); + + MTL::FunctionConstantValues* constantValues() const; + + FunctionDescriptor* functionDescriptor() const; + + SpecializedFunctionDescriptor* init(); + + void setConstantValues(const MTL::FunctionConstantValues* constantValues); + + void setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor); + + void setSpecializedName(const NS::String* specializedName); + NS::String* specializedName() const; +}; + +} +_MTL_INLINE MTL4::SpecializedFunctionDescriptor* MTL4::SpecializedFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4SpecializedFunctionDescriptor)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL4::SpecializedFunctionDescriptor::constantValues() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantValues)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::SpecializedFunctionDescriptor::functionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL4::SpecializedFunctionDescriptor* MTL4::SpecializedFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::SpecializedFunctionDescriptor::setConstantValues(const MTL::FunctionConstantValues* constantValues) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_), constantValues); +} + +_MTL_INLINE void 
MTL4::SpecializedFunctionDescriptor::setFunctionDescriptor(const MTL4::FunctionDescriptor* functionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptor_), functionDescriptor); +} + +_MTL_INLINE void MTL4::SpecializedFunctionDescriptor::setSpecializedName(const NS::String* specializedName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSpecializedName_), specializedName); +} + +_MTL_INLINE NS::String* MTL4::SpecializedFunctionDescriptor::specializedName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(specializedName)); +} diff --git a/dist/include/metal_cpp/Metal/MTL4StitchedFunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTL4StitchedFunctionDescriptor.hpp new file mode 100644 index 0000000..ca8ea5c --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4StitchedFunctionDescriptor.hpp @@ -0,0 +1,86 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4StitchedFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL4 +{ +class StitchedFunctionDescriptor; +} + +namespace MTL +{ +class FunctionStitchingGraph; +} + +namespace MTL4 +{ +class StitchedFunctionDescriptor : public NS::Copying +{ +public: + static StitchedFunctionDescriptor* alloc(); + + NS::Array* functionDescriptors() const; + + MTL::FunctionStitchingGraph* functionGraph() const; + + StitchedFunctionDescriptor* init(); + + void setFunctionDescriptors(const NS::Array* functionDescriptors); + + void setFunctionGraph(const MTL::FunctionStitchingGraph* functionGraph); +}; + +} +_MTL_INLINE MTL4::StitchedFunctionDescriptor* MTL4::StitchedFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4StitchedFunctionDescriptor)); +} + +_MTL_INLINE NS::Array* MTL4::StitchedFunctionDescriptor::functionDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionDescriptors)); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL4::StitchedFunctionDescriptor::functionGraph() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionGraph)); +} + +_MTL_INLINE MTL4::StitchedFunctionDescriptor* MTL4::StitchedFunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL4::StitchedFunctionDescriptor::setFunctionDescriptors(const NS::Array* functionDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionDescriptors_), functionDescriptors); +} + +_MTL_INLINE void MTL4::StitchedFunctionDescriptor::setFunctionGraph(const MTL::FunctionStitchingGraph* functionGraph) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionGraph_), functionGraph); +} diff --git 
a/dist/include/metal_cpp/Metal/MTL4TileRenderPipeline.hpp b/dist/include/metal_cpp/Metal/MTL4TileRenderPipeline.hpp new file mode 100644 index 0000000..dc74f48 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTL4TileRenderPipeline.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTL4TileRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4PipelineState.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL4 +{ +class FunctionDescriptor; +class StaticLinkingDescriptor; +class TileRenderPipelineDescriptor; +} + +namespace MTL +{ +class TileRenderPipelineColorAttachmentDescriptorArray; +} + +namespace MTL4 +{ +class TileRenderPipelineDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineDescriptor* alloc(); + + MTL::TileRenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + TileRenderPipelineDescriptor* init(); + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::UInteger rasterSampleCount() const; + + MTL::Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor); + + void setSupportBinaryLinking(bool supportBinaryLinking); + + void setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize); + + void setTileFunctionDescriptor(const MTL4::FunctionDescriptor* tileFunctionDescriptor); + + StaticLinkingDescriptor* staticLinkingDescriptor() const; + + bool supportBinaryLinking() const; + + bool threadgroupSizeMatchesTileSize() const; + + FunctionDescriptor* tileFunctionDescriptor() const; +}; + +} +_MTL_INLINE MTL4::TileRenderPipelineDescriptor* MTL4::TileRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTL4TileRenderPipelineDescriptor)); +} + +_MTL_INLINE 
MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL4::TileRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL4::TileRenderPipelineDescriptor* MTL4::TileRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL4::TileRenderPipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL4::TileRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE MTL::Size MTL4::TileRenderPipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setStaticLinkingDescriptor(const MTL4::StaticLinkingDescriptor* staticLinkingDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStaticLinkingDescriptor_), staticLinkingDescriptor); +} + +_MTL_INLINE void 
MTL4::TileRenderPipelineDescriptor::setSupportBinaryLinking(bool supportBinaryLinking) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportBinaryLinking_), supportBinaryLinking); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupSizeMatchesTileSize_), threadgroupSizeMatchesTileSize); +} + +_MTL_INLINE void MTL4::TileRenderPipelineDescriptor::setTileFunctionDescriptor(const MTL4::FunctionDescriptor* tileFunctionDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileFunctionDescriptor_), tileFunctionDescriptor); +} + +_MTL_INLINE MTL4::StaticLinkingDescriptor* MTL4::TileRenderPipelineDescriptor::staticLinkingDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticLinkingDescriptor)); +} + +_MTL_INLINE bool MTL4::TileRenderPipelineDescriptor::supportBinaryLinking() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportBinaryLinking)); +} + +_MTL_INLINE bool MTL4::TileRenderPipelineDescriptor::threadgroupSizeMatchesTileSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL4::FunctionDescriptor* MTL4::TileRenderPipelineDescriptor::tileFunctionDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileFunctionDescriptor)); +} diff --git a/dist/include/metal_cpp/Metal/MTLAccelerationStructure.hpp b/dist/include/metal_cpp/Metal/MTLAccelerationStructure.hpp new file mode 100644 index 0000000..d3457c3 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLAccelerationStructure.hpp @@ -0,0 +1,1887 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructure.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLStageInputOutputDescriptor.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructureBoundingBoxGeometryDescriptor; +class AccelerationStructureCurveGeometryDescriptor; +class AccelerationStructureDescriptor; +class AccelerationStructureGeometryDescriptor; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor; +class AccelerationStructureMotionCurveGeometryDescriptor; +class AccelerationStructureMotionTriangleGeometryDescriptor; +class AccelerationStructureTriangleGeometryDescriptor; +class Buffer; +class IndirectInstanceAccelerationStructureDescriptor; +class InstanceAccelerationStructureDescriptor; +class MotionKeyframeData; +class PrimitiveAccelerationStructureDescriptor; +} + +namespace MTL +{ +_MTL_ENUM(NS::Integer, MatrixLayout) { + MatrixLayoutColumnMajor = 0, + MatrixLayoutRowMajor = 1, +}; + +_MTL_ENUM(uint32_t, MotionBorderMode) { + MotionBorderModeClamp = 0, + MotionBorderModeVanish = 1, +}; + +_MTL_ENUM(NS::Integer, CurveType) { + 
CurveTypeRound = 0, + CurveTypeFlat = 1, +}; + +_MTL_ENUM(NS::Integer, CurveBasis) { + CurveBasisBSpline = 0, + CurveBasisCatmullRom = 1, + CurveBasisLinear = 2, + CurveBasisBezier = 3, +}; + +_MTL_ENUM(NS::Integer, CurveEndCaps) { + CurveEndCapsNone = 0, + CurveEndCapsDisk = 1, + CurveEndCapsSphere = 2, +}; + +_MTL_ENUM(NS::UInteger, AccelerationStructureInstanceDescriptorType) { + AccelerationStructureInstanceDescriptorTypeDefault = 0, + AccelerationStructureInstanceDescriptorTypeUserID = 1, + AccelerationStructureInstanceDescriptorTypeMotion = 2, + AccelerationStructureInstanceDescriptorTypeIndirect = 3, + AccelerationStructureInstanceDescriptorTypeIndirectMotion = 4, +}; + +_MTL_ENUM(NS::Integer, TransformType) { + TransformTypePackedFloat4x3 = 0, + TransformTypeComponent = 1, +}; + +_MTL_OPTIONS(NS::UInteger, AccelerationStructureRefitOptions) { + AccelerationStructureRefitOptionVertexData = 1, + AccelerationStructureRefitOptionPerPrimitiveData = 1 << 1, +}; + +_MTL_OPTIONS(NS::UInteger, AccelerationStructureUsage) { + AccelerationStructureUsageNone = 0, + AccelerationStructureUsageRefit = 1, + AccelerationStructureUsagePreferFastBuild = 1 << 1, + AccelerationStructureUsageExtendedLimits = 1 << 2, + AccelerationStructureUsagePreferFastIntersection = 1 << 4, + AccelerationStructureUsageMinimizeMemory = 1 << 5, +}; + +_MTL_OPTIONS(uint32_t, AccelerationStructureInstanceOptions) { + AccelerationStructureInstanceOptionNone = 0, + AccelerationStructureInstanceOptionDisableTriangleCulling = 1, + AccelerationStructureInstanceOptionTriangleFrontFacingWindingCounterClockwise = 1 << 1, + AccelerationStructureInstanceOptionOpaque = 1 << 2, + AccelerationStructureInstanceOptionNonOpaque = 1 << 3, +}; + +struct AccelerationStructureInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; +} _MTL_PACKED; + +struct 
AccelerationStructureUserIDInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; + uint32_t userID; +} _MTL_PACKED; + +struct AccelerationStructureMotionInstanceDescriptor +{ + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t accelerationStructureIndex; + uint32_t userID; + uint32_t motionTransformsStartIndex; + uint32_t motionTransformsCount; + MTL::MotionBorderMode motionStartBorderMode; + MTL::MotionBorderMode motionEndBorderMode; + float motionStartTime; + float motionEndTime; +} _MTL_PACKED; + +struct IndirectAccelerationStructureInstanceDescriptor +{ + MTL::PackedFloat4x3 transformationMatrix; + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t userID; + MTL::ResourceID accelerationStructureID; +} _MTL_PACKED; + +struct IndirectAccelerationStructureMotionInstanceDescriptor +{ + MTL::AccelerationStructureInstanceOptions options; + uint32_t mask; + uint32_t intersectionFunctionTableOffset; + uint32_t userID; + MTL::ResourceID accelerationStructureID; + uint32_t motionTransformsStartIndex; + uint32_t motionTransformsCount; + MTL::MotionBorderMode motionStartBorderMode; + MTL::MotionBorderMode motionEndBorderMode; + float motionStartTime; + float motionEndTime; +} _MTL_PACKED; + +class AccelerationStructureDescriptor : public NS::Copying +{ +public: + static AccelerationStructureDescriptor* alloc(); + + AccelerationStructureDescriptor* init(); + + void setUsage(MTL::AccelerationStructureUsage usage); + AccelerationStructureUsage usage() const; +}; +class AccelerationStructureGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureGeometryDescriptor* alloc(); + + bool allowDuplicateIntersectionFunctionInvocation() const; + + 
AccelerationStructureGeometryDescriptor* init(); + + NS::UInteger intersectionFunctionTableOffset() const; + + NS::String* label() const; + + bool opaque() const; + + Buffer* primitiveDataBuffer() const; + NS::UInteger primitiveDataBufferOffset() const; + + NS::UInteger primitiveDataElementSize() const; + + NS::UInteger primitiveDataStride() const; + + void setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation); + + void setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset); + + void setLabel(const NS::String* label); + + void setOpaque(bool opaque); + + void setPrimitiveDataBuffer(const MTL::Buffer* primitiveDataBuffer); + void setPrimitiveDataBufferOffset(NS::UInteger primitiveDataBufferOffset); + + void setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize); + + void setPrimitiveDataStride(NS::UInteger primitiveDataStride); +}; +class PrimitiveAccelerationStructureDescriptor : public NS::Copying +{ +public: + static PrimitiveAccelerationStructureDescriptor* alloc(); + + static PrimitiveAccelerationStructureDescriptor* descriptor(); + NS::Array* geometryDescriptors() const; + + PrimitiveAccelerationStructureDescriptor* init(); + + MotionBorderMode motionEndBorderMode() const; + + float motionEndTime() const; + + NS::UInteger motionKeyframeCount() const; + + MotionBorderMode motionStartBorderMode() const; + + float motionStartTime() const; + + void setGeometryDescriptors(const NS::Array* geometryDescriptors); + + void setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode); + + void setMotionEndTime(float motionEndTime); + + void setMotionKeyframeCount(NS::UInteger motionKeyframeCount); + + void setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode); + + void setMotionStartTime(float motionStartTime); +}; +class AccelerationStructureTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureTriangleGeometryDescriptor* alloc(); 
+ + static AccelerationStructureTriangleGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer); + void setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffer(const MTL::Buffer* vertexBuffer); + void setVertexBufferOffset(NS::UInteger vertexBufferOffset); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + Buffer* transformationMatrixBuffer() const; + NS::UInteger transformationMatrixBufferOffset() const; + + MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + Buffer* vertexBuffer() const; + NS::UInteger vertexBufferOffset() const; + + AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureBoundingBoxGeometryDescriptor* alloc(); + + Buffer* boundingBoxBuffer() const; + NS::UInteger boundingBoxBufferOffset() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + static AccelerationStructureBoundingBoxGeometryDescriptor* descriptor(); + + AccelerationStructureBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffer(const MTL::Buffer* boundingBoxBuffer); + void setBoundingBoxBufferOffset(NS::UInteger boundingBoxBufferOffset); + + void setBoundingBoxCount(NS::UInteger 
boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class MotionKeyframeData : public NS::Referencing +{ +public: + static MotionKeyframeData* alloc(); + + Buffer* buffer() const; + + static MotionKeyframeData* data(); + + MotionKeyframeData* init(); + + NS::UInteger offset() const; + + void setBuffer(const MTL::Buffer* buffer); + + void setOffset(NS::UInteger offset); +}; +class AccelerationStructureMotionTriangleGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionTriangleGeometryDescriptor* alloc(); + + static AccelerationStructureMotionTriangleGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureMotionTriangleGeometryDescriptor* init(); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer); + void setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset); + + void setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout); + + void setTriangleCount(NS::UInteger triangleCount); + + void setVertexBuffers(const NS::Array* vertexBuffers); + + void setVertexFormat(MTL::AttributeFormat vertexFormat); + + void setVertexStride(NS::UInteger vertexStride); + + Buffer* transformationMatrixBuffer() const; + NS::UInteger transformationMatrixBufferOffset() const; + + MatrixLayout transformationMatrixLayout() const; + + NS::UInteger triangleCount() const; + + NS::Array* vertexBuffers() const; + + AttributeFormat vertexFormat() const; + + NS::UInteger vertexStride() const; +}; +class AccelerationStructureMotionBoundingBoxGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* alloc(); + + 
NS::Array* boundingBoxBuffers() const; + + NS::UInteger boundingBoxCount() const; + + NS::UInteger boundingBoxStride() const; + + static AccelerationStructureMotionBoundingBoxGeometryDescriptor* descriptor(); + + AccelerationStructureMotionBoundingBoxGeometryDescriptor* init(); + + void setBoundingBoxBuffers(const NS::Array* boundingBoxBuffers); + + void setBoundingBoxCount(NS::UInteger boundingBoxCount); + + void setBoundingBoxStride(NS::UInteger boundingBoxStride); +}; +class AccelerationStructureCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureCurveGeometryDescriptor* alloc(); + + Buffer* controlPointBuffer() const; + NS::UInteger controlPointBufferOffset() const; + + NS::UInteger controlPointCount() const; + + AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + CurveBasis curveBasis() const; + + CurveEndCaps curveEndCaps() const; + + CurveType curveType() const; + + static AccelerationStructureCurveGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureCurveGeometryDescriptor* init(); + + Buffer* radiusBuffer() const; + NS::UInteger radiusBufferOffset() const; + + AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffer(const MTL::Buffer* controlPointBuffer); + void setControlPointBufferOffset(NS::UInteger controlPointBufferOffset); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL::Buffer* 
indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); + + void setRadiusBuffer(const MTL::Buffer* radiusBuffer); + void setRadiusBufferOffset(NS::UInteger radiusBufferOffset); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class AccelerationStructureMotionCurveGeometryDescriptor : public NS::Copying +{ +public: + static AccelerationStructureMotionCurveGeometryDescriptor* alloc(); + + NS::Array* controlPointBuffers() const; + + NS::UInteger controlPointCount() const; + + AttributeFormat controlPointFormat() const; + + NS::UInteger controlPointStride() const; + + CurveBasis curveBasis() const; + + CurveEndCaps curveEndCaps() const; + + CurveType curveType() const; + + static AccelerationStructureMotionCurveGeometryDescriptor* descriptor(); + + Buffer* indexBuffer() const; + NS::UInteger indexBufferOffset() const; + + IndexType indexType() const; + + AccelerationStructureMotionCurveGeometryDescriptor* init(); + + NS::Array* radiusBuffers() const; + + AttributeFormat radiusFormat() const; + + NS::UInteger radiusStride() const; + + NS::UInteger segmentControlPointCount() const; + + NS::UInteger segmentCount() const; + + void setControlPointBuffers(const NS::Array* controlPointBuffers); + + void setControlPointCount(NS::UInteger controlPointCount); + + void setControlPointFormat(MTL::AttributeFormat controlPointFormat); + + void setControlPointStride(NS::UInteger controlPointStride); + + void setCurveBasis(MTL::CurveBasis curveBasis); + + void setCurveEndCaps(MTL::CurveEndCaps curveEndCaps); + + void setCurveType(MTL::CurveType curveType); + + void setIndexBuffer(const MTL::Buffer* indexBuffer); + void setIndexBufferOffset(NS::UInteger indexBufferOffset); + + void setIndexType(MTL::IndexType indexType); 
+ + void setRadiusBuffers(const NS::Array* radiusBuffers); + + void setRadiusFormat(MTL::AttributeFormat radiusFormat); + + void setRadiusStride(NS::UInteger radiusStride); + + void setSegmentControlPointCount(NS::UInteger segmentControlPointCount); + + void setSegmentCount(NS::UInteger segmentCount); +}; +class InstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static InstanceAccelerationStructureDescriptor* alloc(); + + static InstanceAccelerationStructureDescriptor* descriptor(); + + InstanceAccelerationStructureDescriptor* init(); + + NS::UInteger instanceCount() const; + + Buffer* instanceDescriptorBuffer() const; + NS::UInteger instanceDescriptorBufferOffset() const; + + NS::UInteger instanceDescriptorStride() const; + + AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MatrixLayout instanceTransformationMatrixLayout() const; + + NS::Array* instancedAccelerationStructures() const; + + Buffer* motionTransformBuffer() const; + NS::UInteger motionTransformBufferOffset() const; + + NS::UInteger motionTransformCount() const; + + NS::UInteger motionTransformStride() const; + + TransformType motionTransformType() const; + + void setInstanceCount(NS::UInteger instanceCount); + + void setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer); + void setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setInstancedAccelerationStructures(const NS::Array* instancedAccelerationStructures); + + void setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer); + void setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset); + + void 
setMotionTransformCount(NS::UInteger motionTransformCount); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class IndirectInstanceAccelerationStructureDescriptor : public NS::Copying +{ +public: + static IndirectInstanceAccelerationStructureDescriptor* alloc(); + + static IndirectInstanceAccelerationStructureDescriptor* descriptor(); + + IndirectInstanceAccelerationStructureDescriptor* init(); + + Buffer* instanceCountBuffer() const; + NS::UInteger instanceCountBufferOffset() const; + + Buffer* instanceDescriptorBuffer() const; + NS::UInteger instanceDescriptorBufferOffset() const; + + NS::UInteger instanceDescriptorStride() const; + + AccelerationStructureInstanceDescriptorType instanceDescriptorType() const; + + MatrixLayout instanceTransformationMatrixLayout() const; + + NS::UInteger maxInstanceCount() const; + + NS::UInteger maxMotionTransformCount() const; + + Buffer* motionTransformBuffer() const; + NS::UInteger motionTransformBufferOffset() const; + + Buffer* motionTransformCountBuffer() const; + NS::UInteger motionTransformCountBufferOffset() const; + + NS::UInteger motionTransformStride() const; + + TransformType motionTransformType() const; + + void setInstanceCountBuffer(const MTL::Buffer* instanceCountBuffer); + void setInstanceCountBufferOffset(NS::UInteger instanceCountBufferOffset); + + void setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer); + void setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset); + + void setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride); + + void setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType); + + void setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout); + + void setMaxInstanceCount(NS::UInteger maxInstanceCount); + + void setMaxMotionTransformCount(NS::UInteger 
maxMotionTransformCount); + + void setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer); + void setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset); + + void setMotionTransformCountBuffer(const MTL::Buffer* motionTransformCountBuffer); + void setMotionTransformCountBufferOffset(NS::UInteger motionTransformCountBufferOffset); + + void setMotionTransformStride(NS::UInteger motionTransformStride); + + void setMotionTransformType(MTL::TransformType motionTransformType); +}; +class AccelerationStructure : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + NS::UInteger size() const; +}; + +} + +_MTL_INLINE MTL::AccelerationStructureDescriptor* MTL::AccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureDescriptor* MTL::AccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureDescriptor::setUsage(MTL::AccelerationStructureUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE MTL::AccelerationStructureUsage MTL::AccelerationStructureDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE MTL::AccelerationStructureGeometryDescriptor* MTL::AccelerationStructureGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureGeometryDescriptor)); +} + +_MTL_INLINE bool MTL::AccelerationStructureGeometryDescriptor::allowDuplicateIntersectionFunctionInvocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowDuplicateIntersectionFunctionInvocation)); +} + +_MTL_INLINE MTL::AccelerationStructureGeometryDescriptor* MTL::AccelerationStructureGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger 
MTL::AccelerationStructureGeometryDescriptor::intersectionFunctionTableOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(intersectionFunctionTableOffset)); +} + +_MTL_INLINE NS::String* MTL::AccelerationStructureGeometryDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::AccelerationStructureGeometryDescriptor::opaque() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(opaque)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureGeometryDescriptor::primitiveDataBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataElementSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataElementSize)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureGeometryDescriptor::primitiveDataStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(primitiveDataStride)); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setAllowDuplicateIntersectionFunctionInvocation(bool allowDuplicateIntersectionFunctionInvocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowDuplicateIntersectionFunctionInvocation_), allowDuplicateIntersectionFunctionInvocation); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setIntersectionFunctionTableOffset(NS::UInteger intersectionFunctionTableOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTableOffset_), intersectionFunctionTableOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + 
+_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setOpaque(bool opaque) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaque_), opaque); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataBuffer(const MTL::Buffer* primitiveDataBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBuffer_), primitiveDataBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataBufferOffset(NS::UInteger primitiveDataBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataBufferOffset_), primitiveDataBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataElementSize(NS::UInteger primitiveDataElementSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataElementSize_), primitiveDataElementSize); +} + +_MTL_INLINE void MTL::AccelerationStructureGeometryDescriptor::setPrimitiveDataStride(NS::UInteger primitiveDataStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrimitiveDataStride_), primitiveDataStride); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPrimitiveAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLPrimitiveAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE NS::Array* MTL::PrimitiveAccelerationStructureDescriptor::geometryDescriptors() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(geometryDescriptors)); +} + +_MTL_INLINE MTL::PrimitiveAccelerationStructureDescriptor* MTL::PrimitiveAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::MotionBorderMode 
MTL::PrimitiveAccelerationStructureDescriptor::motionEndBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndBorderMode)); +} + +_MTL_INLINE float MTL::PrimitiveAccelerationStructureDescriptor::motionEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionEndTime)); +} + +_MTL_INLINE NS::UInteger MTL::PrimitiveAccelerationStructureDescriptor::motionKeyframeCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionKeyframeCount)); +} + +_MTL_INLINE MTL::MotionBorderMode MTL::PrimitiveAccelerationStructureDescriptor::motionStartBorderMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartBorderMode)); +} + +_MTL_INLINE float MTL::PrimitiveAccelerationStructureDescriptor::motionStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionStartTime)); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setGeometryDescriptors(const NS::Array* geometryDescriptors) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGeometryDescriptors_), geometryDescriptors); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionEndBorderMode(MTL::MotionBorderMode motionEndBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndBorderMode_), motionEndBorderMode); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionEndTime(float motionEndTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionEndTime_), motionEndTime); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionKeyframeCount(NS::UInteger motionKeyframeCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionKeyframeCount_), motionKeyframeCount); +} + +_MTL_INLINE void MTL::PrimitiveAccelerationStructureDescriptor::setMotionStartBorderMode(MTL::MotionBorderMode motionStartBorderMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartBorderMode_), motionStartBorderMode); +} + +_MTL_INLINE 
void MTL::PrimitiveAccelerationStructureDescriptor::setMotionStartTime(float motionStartTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionStartTime_), motionStartTime); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureTriangleGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureTriangleGeometryDescriptor* MTL::AccelerationStructureTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), 
indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexBuffer(const MTL::Buffer* vertexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_), vertexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexBufferOffset(NS::UInteger vertexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_), vertexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + 
return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBufferOffset)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureTriangleGeometryDescriptor::vertexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::vertexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBufferOffset)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxBufferOffset() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureBoundingBoxGeometryDescriptor* MTL::AccelerationStructureBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBuffer(const MTL::Buffer* boundingBoxBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffer_), boundingBoxBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxBufferOffset(NS::UInteger boundingBoxBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBufferOffset_), boundingBoxBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL::AccelerationStructureBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLMotionKeyframeData)); +} + +_MTL_INLINE 
MTL::Buffer* MTL::MotionKeyframeData::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::data() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLMotionKeyframeData), _MTL_PRIVATE_SEL(data)); +} + +_MTL_INLINE MTL::MotionKeyframeData* MTL::MotionKeyframeData::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::MotionKeyframeData::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::MotionKeyframeData::setBuffer(const MTL::Buffer* buffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_), buffer); +} + +_MTL_INLINE void MTL::MotionKeyframeData::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureMotionTriangleGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE 
MTL::AccelerationStructureMotionTriangleGeometryDescriptor* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBuffer(const MTL::Buffer* transformationMatrixBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBuffer_), transformationMatrixBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixBufferOffset(NS::UInteger transformationMatrixBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixBufferOffset_), transformationMatrixBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTransformationMatrixLayout(MTL::MatrixLayout transformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTransformationMatrixLayout_), transformationMatrixLayout); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setTriangleCount(NS::UInteger triangleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleCount_), triangleCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexBuffers(const NS::Array* vertexBuffers) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setVertexBuffers_), vertexBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexFormat(MTL::AttributeFormat vertexFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFormat_), vertexFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionTriangleGeometryDescriptor::setVertexStride(NS::UInteger vertexStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexStride_), vertexStride); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixBufferOffset)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::AccelerationStructureMotionTriangleGeometryDescriptor::transformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(transformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::triangleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(triangleCount)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionTriangleGeometryDescriptor::vertexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexStride)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* 
MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::boundingBoxStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(boundingBoxStride)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor* MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxBuffers(const NS::Array* boundingBoxBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxBuffers_), boundingBoxBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxCount(NS::UInteger boundingBoxCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxCount_), boundingBoxCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionBoundingBoxGeometryDescriptor::setBoundingBoxStride(NS::UInteger boundingBoxStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBoundingBoxStride_), boundingBoxStride); +} + +_MTL_INLINE 
MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureCurveGeometryDescriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureCurveGeometryDescriptor::controlPointBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL::AccelerationStructureCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL::AccelerationStructureCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL::AccelerationStructureCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureCurveGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* 
MTL::AccelerationStructureCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureCurveGeometryDescriptor* MTL::AccelerationStructureCurveGeometryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureCurveGeometryDescriptor::radiusBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::radiusBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBufferOffset)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointBuffer(const MTL::Buffer* controlPointBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffer_), controlPointBuffer); +} + +_MTL_INLINE void 
MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointBufferOffset(NS::UInteger controlPointBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBufferOffset_), controlPointBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void 
MTL::AccelerationStructureCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusBuffer(const MTL::Buffer* radiusBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffer_), radiusBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusBufferOffset(NS::UInteger radiusBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBufferOffset_), radiusBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), segmentControlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointCount() const +{ 
+ return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointCount)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::controlPointStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlPointStride)); +} + +_MTL_INLINE MTL::CurveBasis MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveBasis() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveBasis)); +} + +_MTL_INLINE MTL::CurveEndCaps MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveEndCaps() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveEndCaps)); +} + +_MTL_INLINE MTL::CurveType MTL::AccelerationStructureMotionCurveGeometryDescriptor::curveType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(curveType)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::Buffer* MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferOffset)); +} + +_MTL_INLINE MTL::IndexType MTL::AccelerationStructureMotionCurveGeometryDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::AccelerationStructureMotionCurveGeometryDescriptor* MTL::AccelerationStructureMotionCurveGeometryDescriptor::init() +{ + return 
NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusBuffers)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusFormat)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::radiusStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(radiusStride)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::segmentControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentControlPointCount)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructureMotionCurveGeometryDescriptor::segmentCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(segmentCount)); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointBuffers(const NS::Array* controlPointBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointBuffers_), controlPointBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointCount(NS::UInteger controlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointCount_), controlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointFormat(MTL::AttributeFormat controlPointFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointFormat_), controlPointFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setControlPointStride(NS::UInteger controlPointStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlPointStride_), controlPointStride); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveBasis(MTL::CurveBasis 
curveBasis) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveBasis_), curveBasis); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveEndCaps(MTL::CurveEndCaps curveEndCaps) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveEndCaps_), curveEndCaps); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setCurveType(MTL::CurveType curveType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCurveType_), curveType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBuffer(const MTL::Buffer* indexBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBuffer_), indexBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexBufferOffset(NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferOffset_), indexBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusBuffers(const NS::Array* radiusBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusBuffers_), radiusBuffers); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusFormat(MTL::AttributeFormat radiusFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusFormat_), radiusFormat); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setRadiusStride(NS::UInteger radiusStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRadiusStride_), radiusStride); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentControlPointCount(NS::UInteger segmentControlPointCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentControlPointCount_), 
segmentControlPointCount); +} + +_MTL_INLINE void MTL::AccelerationStructureMotionCurveGeometryDescriptor::setSegmentCount(NS::UInteger segmentCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSegmentCount_), segmentCount); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLInstanceAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::InstanceAccelerationStructureDescriptor* MTL::InstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType MTL::InstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::InstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::Array* MTL::InstanceAccelerationStructureDescriptor::instancedAccelerationStructures() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instancedAccelerationStructures)); +} + +_MTL_INLINE MTL::Buffer* MTL::InstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCount)); +} + +_MTL_INLINE NS::UInteger MTL::InstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL::InstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceCount(NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCount_), instanceCount); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBufferOffset_), instanceDescriptorBufferOffset); +} + +_MTL_INLINE void 
MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setInstancedAccelerationStructures(const NS::Array* instancedAccelerationStructures) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstancedAccelerationStructures_), instancedAccelerationStructures); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBufferOffset_), motionTransformBufferOffset); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformCount(NS::UInteger motionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCount_), motionTransformCount); +} + +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} 
+ +_MTL_INLINE void MTL::InstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIndirectInstanceAccelerationStructureDescriptor)); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::descriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLIndirectInstanceAccelerationStructureDescriptor), _MTL_PRIVATE_SEL(descriptor)); +} + +_MTL_INLINE MTL::IndirectInstanceAccelerationStructureDescriptor* MTL::IndirectInstanceAccelerationStructureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::instanceCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceCountBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceCountBufferOffset)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorStride)); +} + +_MTL_INLINE MTL::AccelerationStructureInstanceDescriptorType 
MTL::IndirectInstanceAccelerationStructureDescriptor::instanceDescriptorType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceDescriptorType)); +} + +_MTL_INLINE MTL::MatrixLayout MTL::IndirectInstanceAccelerationStructureDescriptor::instanceTransformationMatrixLayout() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(instanceTransformationMatrixLayout)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::maxInstanceCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxInstanceCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::maxMotionTransformCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMotionTransformCount)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformBufferOffset)); +} + +_MTL_INLINE MTL::Buffer* MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCountBuffer)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformCountBufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformCountBufferOffset)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformStride)); +} + +_MTL_INLINE MTL::TransformType MTL::IndirectInstanceAccelerationStructureDescriptor::motionTransformType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(motionTransformType)); +} + 
+_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBuffer(const MTL::Buffer* instanceCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBuffer_), instanceCountBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceCountBufferOffset(NS::UInteger instanceCountBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceCountBufferOffset_), instanceCountBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBuffer(const MTL::Buffer* instanceDescriptorBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBuffer_), instanceDescriptorBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorBufferOffset(NS::UInteger instanceDescriptorBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorBufferOffset_), instanceDescriptorBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorStride(NS::UInteger instanceDescriptorStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorStride_), instanceDescriptorStride); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceDescriptorType(MTL::AccelerationStructureInstanceDescriptorType instanceDescriptorType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceDescriptorType_), instanceDescriptorType); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setInstanceTransformationMatrixLayout(MTL::MatrixLayout instanceTransformationMatrixLayout) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstanceTransformationMatrixLayout_), instanceTransformationMatrixLayout); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMaxInstanceCount(NS::UInteger maxInstanceCount) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setMaxInstanceCount_), maxInstanceCount); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMaxMotionTransformCount(NS::UInteger maxMotionTransformCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMotionTransformCount_), maxMotionTransformCount); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBuffer(const MTL::Buffer* motionTransformBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBuffer_), motionTransformBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformBufferOffset(NS::UInteger motionTransformBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformBufferOffset_), motionTransformBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBuffer(const MTL::Buffer* motionTransformCountBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBuffer_), motionTransformCountBuffer); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformCountBufferOffset(NS::UInteger motionTransformCountBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformCountBufferOffset_), motionTransformCountBufferOffset); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformStride(NS::UInteger motionTransformStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformStride_), motionTransformStride); +} + +_MTL_INLINE void MTL::IndirectInstanceAccelerationStructureDescriptor::setMotionTransformType(MTL::TransformType motionTransformType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMotionTransformType_), motionTransformType); +} + +_MTL_INLINE MTL::ResourceID MTL::AccelerationStructure::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + 
+_MTL_INLINE NS::UInteger MTL::AccelerationStructure::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} diff --git a/dist/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp new file mode 100644 index 0000000..5f82344 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp @@ -0,0 +1,260 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructureCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAccelerationStructure.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class AccelerationStructurePassDescriptor; +class AccelerationStructurePassSampleBufferAttachmentDescriptor; +class AccelerationStructurePassSampleBufferAttachmentDescriptorArray; +class Buffer; +class CounterSampleBuffer; +class Fence; +class Heap; +class Resource; + +class AccelerationStructureCommandEncoder : public NS::Referencing +{ +public: + void buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset); + + void copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure); + + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset); + void refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset, 
MTL::AccelerationStructureRefitOptions options); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void updateFence(const MTL::Fence* fence); + + void useHeap(const MTL::Heap* heap); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + + void waitForFence(const MTL::Fence* fence); + + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset); + void writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset, MTL::DataType sizeDataType); +}; +class AccelerationStructurePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static AccelerationStructurePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + AccelerationStructurePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class AccelerationStructurePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static AccelerationStructurePassSampleBufferAttachmentDescriptorArray* alloc(); + + AccelerationStructurePassSampleBufferAttachmentDescriptorArray* init(); + + AccelerationStructurePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger 
attachmentIndex); +}; +class AccelerationStructurePassDescriptor : public NS::Copying +{ +public: + static AccelerationStructurePassDescriptor* accelerationStructurePassDescriptor(); + + static AccelerationStructurePassDescriptor* alloc(); + + AccelerationStructurePassDescriptor* init(); + + AccelerationStructurePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::buildAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(buildAccelerationStructure_descriptor_scratchBuffer_scratchBufferOffset_), accelerationStructure, descriptor, scratchBuffer, scratchBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::copyAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::copyAndCompactAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructure* destinationAccelerationStructure) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_), sourceAccelerationStructure, destinationAccelerationStructure); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger 
scratchBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, scratchBufferOffset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::refitAccelerationStructure(const MTL::AccelerationStructure* sourceAccelerationStructure, const MTL::AccelerationStructureDescriptor* descriptor, const MTL::AccelerationStructure* destinationAccelerationStructure, const MTL::Buffer* scratchBuffer, NS::UInteger scratchBufferOffset, MTL::AccelerationStructureRefitOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_options_), sourceAccelerationStructure, descriptor, destinationAccelerationStructure, scratchBuffer, scratchBufferOffset, options); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, 
usage); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_), accelerationStructure, buffer, offset); +} + +_MTL_INLINE void MTL::AccelerationStructureCommandEncoder::writeCompactedAccelerationStructureSize(const MTL::AccelerationStructure* accelerationStructure, const MTL::Buffer* buffer, NS::UInteger offset, MTL::DataType sizeDataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_sizeDataType_), accelerationStructure, buffer, offset, sizeDataType); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::sampleBuffer() 
const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::AccelerationStructurePassSampleBufferAttachmentDescriptor* attachment, 
NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::accelerationStructurePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassDescriptor), _MTL_PRIVATE_SEL(accelerationStructurePassDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAccelerationStructurePassDescriptor)); +} + +_MTL_INLINE MTL::AccelerationStructurePassDescriptor* MTL::AccelerationStructurePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AccelerationStructurePassSampleBufferAttachmentDescriptorArray* MTL::AccelerationStructurePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/dist/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp b/dist/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp new file mode 100644 index 0000000..a08b1e9 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp @@ -0,0 +1,292 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAccelerationStructureTypes.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLStageInputOutputDescriptor.hpp" + +#include "../Foundation/Foundation.hpp" +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" +struct PackedFloat3 +{ + PackedFloat3(); + PackedFloat3(float x, float y, float z); + + float& operator[](int idx); + float operator[](int idx) const; + + union + { + struct + { + float x; + float y; + float z; + }; + + float elements[3]; + }; +} _MTL_PACKED; +#pragma clang diagnostic pop + +struct PackedFloat4x3 +{ + PackedFloat4x3(); + PackedFloat4x3(const PackedFloat3& col0, const PackedFloat3& col1, const PackedFloat3& col2, const PackedFloat3& col3); + + PackedFloat3& operator[](int idx); + const PackedFloat3& operator[](int idx) const; + + PackedFloat3 columns[4]; +} _MTL_PACKED; + +struct AxisAlignedBoundingBox +{ + AxisAlignedBoundingBox(); + AxisAlignedBoundingBox(PackedFloat3 p); + AxisAlignedBoundingBox(PackedFloat3 min, PackedFloat3 max); + + 
PackedFloat3 min; + PackedFloat3 max; +} _MTL_PACKED; + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnested-anon-types" +struct PackedFloatQuaternion +{ + PackedFloatQuaternion(); + PackedFloatQuaternion(float x, float y, float z, float w); + + float& operator[](int idx); + const float& operator[](int idx) const; + + union + { + struct + { + float x; + float y; + float z; + float w; + }; + + float elements[4]; + }; + +} _MTL_PACKED; +#pragma clang diagnostic pop + +struct ComponentTransform +{ + PackedFloat3 scale; + PackedFloat3 shear; + PackedFloat3 pivot; + PackedFloatQuaternion rotation; + PackedFloat3 translation; +} _MTL_PACKED; + +} + +namespace MTL4 +{ + +struct BufferRange +{ + BufferRange() = default; + BufferRange(uint64_t bufferAddress); + BufferRange(uint64_t bufferAddress, uint64_t length); + + static MTL4::BufferRange Make(uint64_t bufferAddress, uint64_t length); + + uint64_t bufferAddress; + uint64_t length; +} _MTL_PACKED; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3::PackedFloat3() + : x(0.0f) + , y(0.0f) + , z(0.0f) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3::PackedFloat3(float _x, float _y, float _z) + : x(_x) + , y(_y) + , z(_z) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float& MTL::PackedFloat3::operator[](int idx) +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float MTL::PackedFloat3::operator[](int idx) const +{ 
+ return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat4x3::PackedFloat4x3() +{ + columns[0] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[1] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[2] = PackedFloat3(0.0f, 0.0f, 0.0f); + columns[3] = PackedFloat3(0.0f, 0.0f, 0.0f); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat4x3::PackedFloat4x3(const PackedFloat3& col0, const PackedFloat3& col1, const PackedFloat3& col2, const PackedFloat3& col3) +{ + columns[0] = col0; + columns[1] = col1; + columns[2] = col2; + columns[3] = col3; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloat3& MTL::PackedFloat4x3::operator[](int idx) +{ + return columns[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE const MTL::PackedFloat3& MTL::PackedFloat4x3::operator[](int idx) const +{ + return columns[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if __apple_build_version__ > 16000026 +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wnan-infinity-disabled" +#endif // __apple_build_version__ > 16000026 +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox() + : min(INFINITY, INFINITY, INFINITY) + , max(-INFINITY, -INFINITY, -INFINITY) +{ +} +#if __apple_build_version__ > 16000026 +#pragma clang diagnostic pop +#endif // if __apple_build_version__ > 
16000026 + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox(PackedFloat3 p) + : min(p) + , max(p) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::AxisAlignedBoundingBox::AxisAlignedBoundingBox(PackedFloat3 _min, PackedFloat3 _max) + : min(_min) + , max(_max) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion() + : x(0.0f) + , y(0.0f) + , z(0.0f) + , w(0.0f) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::PackedFloatQuaternion::PackedFloatQuaternion(float x, float y, float z, float w) + : x(x) + , y(y) + , z(z) + , w(w) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE float& MTL::PackedFloatQuaternion::operator[](int idx) +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE const float& MTL::PackedFloatQuaternion::operator[](int idx) const +{ + return elements[idx]; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange::BufferRange(uint64_t bufferAddress) +: bufferAddress(bufferAddress) +, length(-1) +{ +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange::BufferRange(uint64_t bufferAddress, uint64_t length) +: bufferAddress(bufferAddress) +, length(length) +{ +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL4::BufferRange MTL4::BufferRange::Make(uint64_t bufferAddress, uint64_t length) +{ + return MTL4::BufferRange(bufferAddress, length); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + diff --git a/dist/include/metal_cpp/Metal/MTLAllocation.hpp b/dist/include/metal_cpp/Metal/MTLAllocation.hpp new file mode 100644 index 0000000..ba20105 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLAllocation.hpp @@ -0,0 +1,40 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLAllocation.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Allocation : public NS::Referencing +{ +public: + NS::UInteger allocatedSize() const; +}; + +} +_MTL_INLINE NS::UInteger MTL::Allocation::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} diff --git a/dist/include/metal_cpp/Metal/MTLArgument.hpp b/dist/include/metal_cpp/Metal/MTLArgument.hpp new file mode 100644 index 0000000..f91bd91 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLArgument.hpp @@ -0,0 +1,787 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLArgument.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTensor.hpp" +#include "MTLTexture.hpp" + +namespace MTL +{ +class Argument; +class ArrayType; +class PointerType; +class StructMember; +class StructType; +class TensorExtents; +class TensorReferenceType; +class TextureReferenceType; +class Type; +_MTL_ENUM(NS::UInteger, IndexType) { + IndexTypeUInt16 = 0, + IndexTypeUInt32 = 1, +}; + +_MTL_ENUM(NS::Integer, BindingType) { + BindingTypeBuffer = 0, + BindingTypeThreadgroupMemory = 1, + BindingTypeTexture = 2, + BindingTypeSampler = 3, + BindingTypeImageblockData = 16, + BindingTypeImageblock = 17, + BindingTypeVisibleFunctionTable = 24, + BindingTypePrimitiveAccelerationStructure = 25, + BindingTypeInstanceAccelerationStructure = 26, + BindingTypeIntersectionFunctionTable = 27, + BindingTypeObjectPayload = 34, + BindingTypeTensor = 37, +}; + +_MTL_ENUM(NS::UInteger, ArgumentType) { + ArgumentTypeBuffer = 0, + ArgumentTypeThreadgroupMemory = 1, + ArgumentTypeTexture = 2, + ArgumentTypeSampler = 3, + ArgumentTypeImageblockData = 16, + ArgumentTypeImageblock = 17, + ArgumentTypeVisibleFunctionTable = 24, + ArgumentTypePrimitiveAccelerationStructure = 25, + ArgumentTypeInstanceAccelerationStructure = 26, + ArgumentTypeIntersectionFunctionTable = 27, +}; + +_MTL_ENUM(NS::UInteger, BindingAccess) { + BindingAccessReadOnly = 0, + BindingAccessReadWrite = 1, + BindingAccessWriteOnly = 2, + ArgumentAccessReadOnly = 0, + ArgumentAccessReadWrite = 1, + ArgumentAccessWriteOnly = 2, +}; + +class Type : public NS::Referencing +{ +public: + static Type* alloc(); + + DataType dataType() const; + + Type* init(); +}; +class StructMember : public NS::Referencing +{ +public: + static 
StructMember* alloc(); + + NS::UInteger argumentIndex() const; + + ArrayType* arrayType(); + + DataType dataType() const; + + StructMember* init(); + + NS::String* name() const; + + NS::UInteger offset() const; + + PointerType* pointerType(); + + StructType* structType(); + + TensorReferenceType* tensorReferenceType(); + + TextureReferenceType* textureReferenceType(); +}; +class StructType : public NS::Referencing +{ +public: + static StructType* alloc(); + + StructType* init(); + + StructMember* memberByName(const NS::String* name); + + NS::Array* members() const; +}; +class ArrayType : public NS::Referencing +{ +public: + static ArrayType* alloc(); + + NS::UInteger argumentIndexStride() const; + + NS::UInteger arrayLength() const; + + ArrayType* elementArrayType(); + + PointerType* elementPointerType(); + + StructType* elementStructType(); + + TensorReferenceType* elementTensorReferenceType(); + + TextureReferenceType* elementTextureReferenceType(); + + DataType elementType() const; + + ArrayType* init(); + + NS::UInteger stride() const; +}; +class PointerType : public NS::Referencing +{ +public: + BindingAccess access() const; + + NS::UInteger alignment() const; + + static PointerType* alloc(); + + NS::UInteger dataSize() const; + + ArrayType* elementArrayType(); + + bool elementIsArgumentBuffer() const; + + StructType* elementStructType(); + + DataType elementType() const; + + PointerType* init(); +}; +class TextureReferenceType : public NS::Referencing +{ +public: + BindingAccess access() const; + + static TextureReferenceType* alloc(); + + TextureReferenceType* init(); + + bool isDepthTexture() const; + + DataType textureDataType() const; + + TextureType textureType() const; +}; +class TensorReferenceType : public NS::Referencing +{ +public: + BindingAccess access() const; + + static TensorReferenceType* alloc(); + + TensorExtents* dimensions() const; + + DataType indexType() const; + + TensorReferenceType* init(); + + TensorDataType tensorDataType() const; 
+}; +class Argument : public NS::Referencing +{ +public: + BindingAccess access() const; + + [[deprecated("please use isActive instead")]] + bool active() const; + + static Argument* alloc(); + + NS::UInteger arrayLength() const; + + NS::UInteger bufferAlignment() const; + + NS::UInteger bufferDataSize() const; + + DataType bufferDataType() const; + + PointerType* bufferPointerType() const; + + StructType* bufferStructType() const; + + NS::UInteger index() const; + + Argument* init(); + + bool isActive() const; + + bool isDepthTexture() const; + + NS::String* name() const; + + DataType textureDataType() const; + + TextureType textureType() const; + + NS::UInteger threadgroupMemoryAlignment() const; + + NS::UInteger threadgroupMemoryDataSize() const; + + ArgumentType type() const; +}; +class Binding : public NS::Referencing +{ +public: + BindingAccess access() const; + + [[deprecated("please use isArgument instead")]] + bool argument() const; + + NS::UInteger index() const; + + bool isArgument() const; + + bool isUsed() const; + + NS::String* name() const; + + BindingType type() const; + + [[deprecated("please use isUsed instead")]] + bool used() const; +}; +class BufferBinding : public NS::Referencing +{ +public: + NS::UInteger bufferAlignment() const; + + NS::UInteger bufferDataSize() const; + + DataType bufferDataType() const; + + PointerType* bufferPointerType() const; + + StructType* bufferStructType() const; +}; +class ThreadgroupBinding : public NS::Referencing +{ +public: + NS::UInteger threadgroupMemoryAlignment() const; + + NS::UInteger threadgroupMemoryDataSize() const; +}; +class TextureBinding : public NS::Referencing +{ +public: + NS::UInteger arrayLength() const; + + [[deprecated("please use isDepthTexture instead")]] + bool depthTexture() const; + bool isDepthTexture() const; + + DataType textureDataType() const; + + TextureType textureType() const; +}; +class ObjectPayloadBinding : public NS::Referencing +{ +public: + NS::UInteger 
objectPayloadAlignment() const; + + NS::UInteger objectPayloadDataSize() const; +}; +class TensorBinding : public NS::Referencing +{ +public: + TensorExtents* dimensions() const; + + DataType indexType() const; + + TensorDataType tensorDataType() const; +}; + +} +_MTL_INLINE MTL::Type* MTL::Type::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLType)); +} + +_MTL_INLINE MTL::DataType MTL::Type::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::Type* MTL::Type::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StructMember* MTL::StructMember::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStructMember)); +} + +_MTL_INLINE NS::UInteger MTL::StructMember::argumentIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndex)); +} + +_MTL_INLINE MTL::ArrayType* MTL::StructMember::arrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayType)); +} + +_MTL_INLINE MTL::DataType MTL::StructMember::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::StructMember* MTL::StructMember::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::StructMember::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE NS::UInteger MTL::StructMember::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE MTL::PointerType* MTL::StructMember::pointerType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructMember::structType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(structType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::StructMember::tensorReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::StructMember::textureReferenceType() +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(textureReferenceType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStructType)); +} + +_MTL_INLINE MTL::StructType* MTL::StructType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StructMember* MTL::StructType::memberByName(const NS::String* name) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(memberByName_), name); +} + +_MTL_INLINE NS::Array* MTL::StructType::members() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(members)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArrayType)); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::argumentIndexStride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndexStride)); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::elementArrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementArrayType)); +} + +_MTL_INLINE MTL::PointerType* MTL::ArrayType::elementPointerType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::ArrayType::elementStructType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementStructType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::ArrayType::elementTensorReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementTensorReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::ArrayType::elementTextureReferenceType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementTextureReferenceType)); +} + +_MTL_INLINE MTL::DataType MTL::ArrayType::elementType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementType)); +} + +_MTL_INLINE MTL::ArrayType* MTL::ArrayType::init() +{ + return 
NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::ArrayType::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::BindingAccess MTL::PointerType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE NS::UInteger MTL::PointerType::alignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alignment)); +} + +_MTL_INLINE MTL::PointerType* MTL::PointerType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPointerType)); +} + +_MTL_INLINE NS::UInteger MTL::PointerType::dataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataSize)); +} + +_MTL_INLINE MTL::ArrayType* MTL::PointerType::elementArrayType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementArrayType)); +} + +_MTL_INLINE bool MTL::PointerType::elementIsArgumentBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementIsArgumentBuffer)); +} + +_MTL_INLINE MTL::StructType* MTL::PointerType::elementStructType() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementStructType)); +} + +_MTL_INLINE MTL::DataType MTL::PointerType::elementType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(elementType)); +} + +_MTL_INLINE MTL::PointerType* MTL::PointerType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BindingAccess MTL::TextureReferenceType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::TextureReferenceType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureReferenceType)); +} + +_MTL_INLINE MTL::TextureReferenceType* MTL::TextureReferenceType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::TextureReferenceType::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE MTL::DataType MTL::TextureReferenceType::textureDataType() const 
+{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureReferenceType::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::BindingAccess MTL::TensorReferenceType::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::TensorReferenceType::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorReferenceType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorReferenceType::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::DataType MTL::TensorReferenceType::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::TensorReferenceType* MTL::TensorReferenceType::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorReferenceType::tensorDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorDataType)); +} + +_MTL_INLINE MTL::BindingAccess MTL::Argument::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE bool MTL::Argument::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::Argument* MTL::Argument::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArgument)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::bufferAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::bufferDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataSize)); +} + +_MTL_INLINE MTL::DataType MTL::Argument::bufferDataType() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(bufferDataType)); +} + +_MTL_INLINE MTL::PointerType* MTL::Argument::bufferPointerType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::Argument::bufferStructType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferStructType)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::Argument* MTL::Argument::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::Argument::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::Argument::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE NS::String* MTL::Argument::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::DataType MTL::Argument::textureDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::Argument::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::threadgroupMemoryAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::Argument::threadgroupMemoryDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryDataSize)); +} + +_MTL_INLINE MTL::ArgumentType MTL::Argument::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::BindingAccess MTL::Binding::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE bool MTL::Binding::argument() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isArgument)); +} + +_MTL_INLINE NS::UInteger MTL::Binding::index() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE bool MTL::Binding::isArgument() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isArgument)); +} + +_MTL_INLINE bool MTL::Binding::isUsed() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isUsed)); +} + +_MTL_INLINE NS::String* MTL::Binding::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::BindingType MTL::Binding::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE bool MTL::Binding::used() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isUsed)); +} + +_MTL_INLINE NS::UInteger MTL::BufferBinding::bufferAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::BufferBinding::bufferDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataSize)); +} + +_MTL_INLINE MTL::DataType MTL::BufferBinding::bufferDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferDataType)); +} + +_MTL_INLINE MTL::PointerType* MTL::BufferBinding::bufferPointerType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferPointerType)); +} + +_MTL_INLINE MTL::StructType* MTL::BufferBinding::bufferStructType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferStructType)); +} + +_MTL_INLINE NS::UInteger MTL::ThreadgroupBinding::threadgroupMemoryAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::ThreadgroupBinding::threadgroupMemoryDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryDataSize)); +} + +_MTL_INLINE NS::UInteger MTL::TextureBinding::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE bool MTL::TextureBinding::depthTexture() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE bool MTL::TextureBinding::isDepthTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthTexture)); +} + +_MTL_INLINE MTL::DataType MTL::TextureBinding::textureDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureDataType)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureBinding::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE NS::UInteger MTL::ObjectPayloadBinding::objectPayloadAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectPayloadAlignment)); +} + +_MTL_INLINE NS::UInteger MTL::ObjectPayloadBinding::objectPayloadDataSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectPayloadDataSize)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorBinding::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::DataType MTL::TensorBinding::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorBinding::tensorDataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorDataType)); +} diff --git a/dist/include/metal_cpp/Metal/MTLArgumentEncoder.hpp b/dist/include/metal_cpp/Metal/MTLArgumentEncoder.hpp new file mode 100644 index 0000000..83dbbc2 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLArgumentEncoder.hpp @@ -0,0 +1,235 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLArgumentEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLDepthStencil.hpp" + +namespace MTL +{ +class AccelerationStructure; +class ArgumentEncoder; +class Buffer; +class ComputePipelineState; +class Device; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class RenderPipelineState; +class SamplerState; +class Texture; +class VisibleFunctionTable; + +static const NS::UInteger AttributeStrideStatic = NS::UIntegerMax; + +class ArgumentEncoder : public NS::Referencing +{ +public: + NS::UInteger alignment() const; + + void* constantData(NS::UInteger index); + + Device* device() const; + + NS::UInteger encodedLength() const; + + NS::String* label() const; + + ArgumentEncoder* newArgumentEncoder(NS::UInteger index); + + void setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger index); + + void setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger offset); + void setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger startOffset, NS::UInteger arrayElement); + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setComputePipelineState(const MTL::ComputePipelineState* 
pipeline, NS::UInteger index); + void setComputePipelineStates(const MTL::ComputePipelineState* const pipelines[], NS::Range range); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState, NS::UInteger index); + void setDepthStencilStates(const MTL::DepthStencilState* const depthStencilStates[], NS::Range range); + + void setIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::UInteger index); + void setIndirectCommandBuffers(const MTL::IndirectCommandBuffer* const buffers[], NS::Range range); + + void setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger index); + void setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setLabel(const NS::String* label); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipeline, NS::UInteger index); + void setRenderPipelineStates(const MTL::RenderPipelineState* const pipelines[], NS::Range range); + + void setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + + void setTexture(const MTL::Texture* texture, NS::UInteger index); + void setTextures(const MTL::Texture* const textures[], NS::Range range); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger index); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range); +}; + +} + +_MTL_INLINE NS::UInteger MTL::ArgumentEncoder::alignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alignment)); +} + +_MTL_INLINE void* MTL::ArgumentEncoder::constantData(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantDataAtIndex_), index); +} + +_MTL_INLINE MTL::Device* MTL::ArgumentEncoder::device() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentEncoder::encodedLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(encodedLength)); +} + +_MTL_INLINE NS::String* MTL::ArgumentEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::ArgumentEncoder::newArgumentEncoder(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderForBufferAtIndex_), index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccelerationStructure_atIndex_), accelerationStructure, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentBuffer_offset_), argumentBuffer, offset); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setArgumentBuffer(const MTL::Buffer* argumentBuffer, NS::UInteger startOffset, NS::UInteger arrayElement) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentBuffer_startOffset_arrayElement_), argumentBuffer, startOffset, arrayElement); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setComputePipelineState(const MTL::ComputePipelineState* pipeline, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_atIndex_), pipeline, index); +} + +_MTL_INLINE 
void MTL::ArgumentEncoder::setComputePipelineStates(const MTL::ComputePipelineState* const pipelines[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineStates_withRange_), pipelines, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_atIndex_), depthStencilState, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setDepthStencilStates(const MTL::DepthStencilState* const depthStencilStates[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilStates_withRange_), depthStencilStates, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndirectCommandBuffer_atIndex_), indirectCommandBuffer, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIndirectCommandBuffers(const MTL::IndirectCommandBuffer* const buffers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndirectCommandBuffers_withRange_), buffers, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTable_atIndex_), intersectionFunctionTable, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTables_withRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void 
MTL::ArgumentEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipeline, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_atIndex_), pipeline, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setRenderPipelineStates(const MTL::RenderPipelineState* const pipelines[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineStates_withRange_), pipelines, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atIndex_), visibleFunctionTable, index); +} + +_MTL_INLINE void MTL::ArgumentEncoder::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withRange_), visibleFunctionTables, range); +} diff --git a/dist/include/metal_cpp/Metal/MTLBinaryArchive.hpp b/dist/include/metal_cpp/Metal/MTLBinaryArchive.hpp new file mode 100644 index 0000000..c3f1689 --- /dev/null +++ 
b/dist/include/metal_cpp/Metal/MTLBinaryArchive.hpp @@ -0,0 +1,152 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBinaryArchive.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class BinaryArchiveDescriptor; +class ComputePipelineDescriptor; +class Device; +class FunctionDescriptor; +class Library; +class MeshRenderPipelineDescriptor; +class RenderPipelineDescriptor; +class StitchedLibraryDescriptor; +class TileRenderPipelineDescriptor; +_MTL_ENUM(NS::UInteger, BinaryArchiveError) { + BinaryArchiveErrorNone = 0, + BinaryArchiveErrorInvalidFile = 1, + BinaryArchiveErrorUnexpectedElement = 2, + BinaryArchiveErrorCompilationFailure = 3, + BinaryArchiveErrorInternalError = 4, +}; + +_MTL_CONST(NS::ErrorDomain, BinaryArchiveDomain); +class BinaryArchiveDescriptor : public NS::Copying +{ +public: + static BinaryArchiveDescriptor* alloc(); + + BinaryArchiveDescriptor* init(); + + void setUrl(const NS::URL* url); + NS::URL* url() const; +}; +class BinaryArchive : 
public NS::Referencing +{ +public: + bool addComputePipelineFunctions(const MTL::ComputePipelineDescriptor* descriptor, NS::Error** error); + + bool addFunction(const MTL::FunctionDescriptor* descriptor, const MTL::Library* library, NS::Error** error); + + bool addLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error); + + bool addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor* descriptor, NS::Error** error); + + bool addRenderPipelineFunctions(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error); + + bool addTileRenderPipelineFunctions(const MTL::TileRenderPipelineDescriptor* descriptor, NS::Error** error); + + Device* device() const; + + NS::String* label() const; + + bool serializeToURL(const NS::URL* url, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, BinaryArchiveDomain); +_MTL_INLINE MTL::BinaryArchiveDescriptor* MTL::BinaryArchiveDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBinaryArchiveDescriptor)); +} + +_MTL_INLINE MTL::BinaryArchiveDescriptor* MTL::BinaryArchiveDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::BinaryArchiveDescriptor::setUrl(const NS::URL* url) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUrl_), url); +} + +_MTL_INLINE NS::URL* MTL::BinaryArchiveDescriptor::url() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(url)); +} + +_MTL_INLINE bool MTL::BinaryArchive::addComputePipelineFunctions(const MTL::ComputePipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addComputePipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addFunction(const MTL::FunctionDescriptor* descriptor, const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addFunctionWithDescriptor_library_error_), descriptor, library, 
error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addLibraryWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addMeshRenderPipelineFunctions(const MTL::MeshRenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addRenderPipelineFunctions(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE bool MTL::BinaryArchive::addTileRenderPipelineFunctions(const MTL::TileRenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Device* MTL::BinaryArchive::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::BinaryArchive::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::BinaryArchive::serializeToURL(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeToURL_error_), url, error); +} + +_MTL_INLINE void MTL::BinaryArchive::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp new file mode 100644 index 0000000..319f05e --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp @@ -0,0 +1,226 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBlitCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class CounterSampleBuffer; +class Fence; +class IndirectCommandBuffer; +class Resource; +class Tensor; +class TensorExtents; +class Texture; + +_MTL_OPTIONS(NS::UInteger, BlitOption) { + BlitOptionNone = 0, + BlitOptionDepthFromDepthStencil = 1, + BlitOptionStencilFromDepthStencil = 1 << 1, + BlitOptionRowLinearPVRTC = 1 << 2, +}; + +class BlitCommandEncoder : public NS::Referencing +{ +public: + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger 
sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options); + void copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size); + + void copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions); + + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options); + void copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount); + void copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture); + 
+ void copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex); + + void fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value); + + void generateMipmaps(const MTL::Texture* texture); + + void getTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice, bool resetCounters, const MTL::Buffer* countersBuffer, NS::UInteger countersBufferOffset); + + void optimizeContentsForCPUAccess(const MTL::Texture* texture); + void optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeContentsForGPUAccess(const MTL::Texture* texture); + void optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range); + + void resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range); + + void resetTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice); + + void resolveCounters(const MTL::CounterSampleBuffer* sampleBuffer, NS::Range range, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void synchronizeResource(const MTL::Resource* resource); + + void synchronizeTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level); + + void updateFence(const MTL::Fence* fence); + + void waitForFence(const MTL::Fence* fence); +}; + +} +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, 
NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin, MTL::BlitOption options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_), sourceBuffer, sourceOffset, sourceBytesPerRow, sourceBytesPerImage, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin, options); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromBuffer(const MTL::Buffer* sourceBuffer, NS::UInteger sourceOffset, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_), sourceBuffer, sourceOffset, destinationBuffer, destinationOffset, size); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTensor(const MTL::Tensor* sourceTensor, const MTL::TensorExtents* sourceOrigin, const MTL::TensorExtents* sourceDimensions, const MTL::Tensor* destinationTensor, const MTL::TensorExtents* destinationOrigin, const MTL::TensorExtents* destinationDimensions) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_), sourceTensor, sourceOrigin, sourceDimensions, destinationTensor, destinationOrigin, destinationDimensions); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset, NS::UInteger destinationBytesPerRow, NS::UInteger destinationBytesPerImage, MTL::BlitOption options) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationBuffer, destinationOffset, destinationBytesPerRow, destinationBytesPerImage, options); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, NS::UInteger sliceCount, NS::UInteger levelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_), sourceTexture, sourceSlice, sourceLevel, destinationTexture, destinationSlice, destinationLevel, sliceCount, levelCount); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyFromTexture(const MTL::Texture* sourceTexture, const MTL::Texture* destinationTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyFromTexture_toTexture_), sourceTexture, destinationTexture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::copyIndirectCommandBuffer(const MTL::IndirectCommandBuffer* source, NS::Range sourceRange, const MTL::IndirectCommandBuffer* destination, NS::UInteger destinationIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_), source, sourceRange, destination, destinationIndex); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::fillBuffer(const MTL::Buffer* buffer, NS::Range range, uint8_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(fillBuffer_range_value_), buffer, range, value); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::generateMipmaps(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(generateMipmapsForTexture_), texture); +} + +_MTL_INLINE void 
MTL::BlitCommandEncoder::getTextureAccessCounters(const MTL::Texture* texture, MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice, bool resetCounters, const MTL::Buffer* countersBuffer, NS::UInteger countersBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getTextureAccessCounters_region_mipLevel_slice_resetCounters_countersBuffer_countersBufferOffset_), texture, region, mipLevel, slice, resetCounters, countersBuffer, countersBufferOffset); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_), texture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForCPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForCPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_), texture); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeContentsForGPUAccess(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeContentsForGPUAccess_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::optimizeIndirectCommandBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizeIndirectCommandBuffer_withRange_), indirectCommandBuffer, range); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resetCommandsInBuffer(const MTL::IndirectCommandBuffer* buffer, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetCommandsInBuffer_withRange_), buffer, range); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resetTextureAccessCounters(const MTL::Texture* texture, 
MTL::Region region, NS::UInteger mipLevel, NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetTextureAccessCounters_region_mipLevel_slice_), texture, region, mipLevel, slice); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::resolveCounters(const MTL::CounterSampleBuffer* sampleBuffer, NS::Range range, const MTL::Buffer* destinationBuffer, NS::UInteger destinationOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounters_inRange_destinationBuffer_destinationOffset_), sampleBuffer, range, destinationBuffer, destinationOffset); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::synchronizeResource(const MTL::Resource* resource) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(synchronizeResource_), resource); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::synchronizeTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(synchronizeTexture_slice_level_), texture, slice, level); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::BlitCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/dist/include/metal_cpp/Metal/MTLBlitPass.hpp b/dist/include/metal_cpp/Metal/MTLBlitPass.hpp new file mode 100644 index 0000000..6b15e0b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLBlitPass.hpp @@ -0,0 +1,154 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// 
+// Metal/MTLBlitPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class BlitPassDescriptor; +class BlitPassSampleBufferAttachmentDescriptor; +class BlitPassSampleBufferAttachmentDescriptorArray; +class CounterSampleBuffer; + +class BlitPassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static BlitPassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + BlitPassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class BlitPassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static BlitPassSampleBufferAttachmentDescriptorArray* alloc(); + + BlitPassSampleBufferAttachmentDescriptorArray* init(); + + BlitPassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void 
setObject(const MTL::BlitPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class BlitPassDescriptor : public NS::Copying +{ +public: + static BlitPassDescriptor* alloc(); + + static BlitPassDescriptor* blitPassDescriptor(); + + BlitPassDescriptor* init(); + + BlitPassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::BlitPassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::BlitPassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::BlitPassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptor* MTL::BlitPassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::BlitPassSampleBufferAttachmentDescriptorArray::setObject(const MTL::BlitPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBlitPassDescriptor)); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::blitPassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLBlitPassDescriptor), _MTL_PRIVATE_SEL(blitPassDescriptor)); +} + +_MTL_INLINE MTL::BlitPassDescriptor* MTL::BlitPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BlitPassSampleBufferAttachmentDescriptorArray* MTL::BlitPassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/dist/include/metal_cpp/Metal/MTLBuffer.hpp b/dist/include/metal_cpp/Metal/MTLBuffer.hpp new file mode 100644 index 0000000..a93be1b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLBuffer.hpp @@ -0,0 +1,119 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLGPUAddress.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" + +namespace MTL +{ +class Buffer; +class Device; +class Tensor; +class TensorDescriptor; +class Texture; +class TextureDescriptor; + +class Buffer : public NS::Referencing +{ +public: + void addDebugMarker(const NS::String* marker, NS::Range range); + + void* contents(); + + void didModifyRange(NS::Range range); + + GPUAddress gpuAddress() const; + + NS::UInteger length() const; + + Buffer* newRemoteBufferViewForDevice(const MTL::Device* device); + + Tensor* newTensor(const MTL::TensorDescriptor* descriptor, NS::UInteger offset, NS::Error** error); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow); + + Buffer* remoteStorageBuffer() const; + + void removeAllDebugMarkers(); + + BufferSparseTier sparseBufferTier() const; +}; + +} +_MTL_INLINE void 
MTL::Buffer::addDebugMarker(const NS::String* marker, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addDebugMarker_range_), marker, range); +} + +_MTL_INLINE void* MTL::Buffer::contents() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(contents)); +} + +_MTL_INLINE void MTL::Buffer::didModifyRange(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(didModifyRange_), range); +} + +_MTL_INLINE MTL::GPUAddress MTL::Buffer::gpuAddress() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuAddress)); +} + +_MTL_INLINE NS::UInteger MTL::Buffer::length() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(length)); +} + +_MTL_INLINE MTL::Buffer* MTL::Buffer::newRemoteBufferViewForDevice(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRemoteBufferViewForDevice_), device); +} + +_MTL_INLINE MTL::Tensor* MTL::Buffer::newTensor(const MTL::TensorDescriptor* descriptor, NS::UInteger offset, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTensorWithDescriptor_offset_error_), descriptor, offset, error); +} + +_MTL_INLINE MTL::Texture* MTL::Buffer::newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_offset_bytesPerRow_), descriptor, offset, bytesPerRow); +} + +_MTL_INLINE MTL::Buffer* MTL::Buffer::remoteStorageBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(remoteStorageBuffer)); +} + +_MTL_INLINE void MTL::Buffer::removeAllDebugMarkers() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllDebugMarkers)); +} + +_MTL_INLINE MTL::BufferSparseTier MTL::Buffer::sparseBufferTier() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseBufferTier)); +} diff --git a/dist/include/metal_cpp/Metal/MTLCaptureManager.hpp b/dist/include/metal_cpp/Metal/MTLCaptureManager.hpp new file mode 100644 index 
0000000..a762241 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCaptureManager.hpp @@ -0,0 +1,217 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCaptureManager.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CaptureDescriptor; +class CaptureManager; +class CaptureScope; +class CommandQueue; +class Device; +} + +namespace MTL4 +{ +class CommandQueue; +} + +namespace MTL +{ +_MTL_ENUM(NS::Integer, CaptureError) { + CaptureErrorNotSupported = 1, + CaptureErrorAlreadyCapturing = 2, + CaptureErrorInvalidDescriptor = 3, +}; + +_MTL_ENUM(NS::Integer, CaptureDestination) { + CaptureDestinationDeveloperTools = 1, + CaptureDestinationGPUTraceDocument = 2, +}; + +class CaptureDescriptor : public NS::Copying +{ +public: + static CaptureDescriptor* alloc(); + + NS::Object* captureObject() const; + + CaptureDestination destination() const; + + CaptureDescriptor* init(); + + NS::URL* outputURL() const; + + void setCaptureObject(NS::Object* captureObject); + + 
void setDestination(MTL::CaptureDestination destination); + + void setOutputURL(const NS::URL* outputURL); +}; +class CaptureManager : public NS::Referencing +{ +public: + static CaptureManager* alloc(); + + CaptureScope* defaultCaptureScope() const; + + CaptureManager* init(); + + bool isCapturing() const; + + CaptureScope* newCaptureScope(const MTL::Device* device); + CaptureScope* newCaptureScope(const MTL::CommandQueue* commandQueue); + CaptureScope* newCaptureScope(const MTL4::CommandQueue* commandQueue); + + void setDefaultCaptureScope(const MTL::CaptureScope* defaultCaptureScope); + + static CaptureManager* sharedCaptureManager(); + + bool startCapture(const MTL::CaptureDescriptor* descriptor, NS::Error** error); + void startCapture(const MTL::Device* device); + void startCapture(const MTL::CommandQueue* commandQueue); + void startCapture(const MTL::CaptureScope* captureScope); + + void stopCapture(); + + bool supportsDestination(MTL::CaptureDestination destination); +}; + +} +_MTL_INLINE MTL::CaptureDescriptor* MTL::CaptureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCaptureDescriptor)); +} + +_MTL_INLINE NS::Object* MTL::CaptureDescriptor::captureObject() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(captureObject)); +} + +_MTL_INLINE MTL::CaptureDestination MTL::CaptureDescriptor::destination() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destination)); +} + +_MTL_INLINE MTL::CaptureDescriptor* MTL::CaptureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::URL* MTL::CaptureDescriptor::outputURL() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(outputURL)); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setCaptureObject(NS::Object* captureObject) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCaptureObject_), captureObject); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setDestination(MTL::CaptureDestination destination) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setDestination_), destination); +} + +_MTL_INLINE void MTL::CaptureDescriptor::setOutputURL(const NS::URL* outputURL) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOutputURL_), outputURL); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCaptureManager)); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::defaultCaptureScope() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultCaptureScope)); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::CaptureManager::isCapturing() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isCapturing)); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithDevice_), device); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL::CommandQueue* commandQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithCommandQueue_), commandQueue); +} + +_MTL_INLINE MTL::CaptureScope* MTL::CaptureManager::newCaptureScope(const MTL4::CommandQueue* commandQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCaptureScopeWithMTL4CommandQueue_), commandQueue); +} + +_MTL_INLINE void MTL::CaptureManager::setDefaultCaptureScope(const MTL::CaptureScope* defaultCaptureScope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultCaptureScope_), defaultCaptureScope); +} + +_MTL_INLINE MTL::CaptureManager* MTL::CaptureManager::sharedCaptureManager() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLCaptureManager), _MTL_PRIVATE_SEL(sharedCaptureManager)); +} + +_MTL_INLINE bool MTL::CaptureManager::startCapture(const MTL::CaptureDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithDescriptor_error_), 
descriptor, error); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::Device* device) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithDevice_), device); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::CommandQueue* commandQueue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithCommandQueue_), commandQueue); +} + +_MTL_INLINE void MTL::CaptureManager::startCapture(const MTL::CaptureScope* captureScope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(startCaptureWithScope_), captureScope); +} + +_MTL_INLINE void MTL::CaptureManager::stopCapture() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(stopCapture)); +} + +_MTL_INLINE bool MTL::CaptureManager::supportsDestination(MTL::CaptureDestination destination) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsDestination_), destination); +} diff --git a/dist/include/metal_cpp/Metal/MTLCaptureScope.hpp b/dist/include/metal_cpp/Metal/MTLCaptureScope.hpp new file mode 100644 index 0000000..96ade67 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCaptureScope.hpp @@ -0,0 +1,91 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCaptureScope.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" +#include "MTLPrivate.hpp" + +#include "../Foundation/Foundation.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +class CaptureScope : public NS::Referencing +{ +public: + class Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* pLabel); + + class CommandQueue* commandQueue() const; + + void beginScope(); + void endScope(); +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::Device* MTL::CaptureScope::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE NS::String* MTL::CaptureScope::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::setLabel(const NS::String* pLabel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), pLabel); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE MTL::CommandQueue* 
MTL::CaptureScope::commandQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandQueue)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::beginScope() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(beginScope)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTL_INLINE void MTL::CaptureScope::endScope() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endScope)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Metal/MTLCommandBuffer.hpp b/dist/include/metal_cpp/Metal/MTLCommandBuffer.hpp new file mode 100644 index 0000000..c504573 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCommandBuffer.hpp @@ -0,0 +1,464 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include +#include + +#include + +namespace MTL +{ +class AccelerationStructureCommandEncoder; +class AccelerationStructurePassDescriptor; +class BlitCommandEncoder; +class BlitPassDescriptor; +class CommandBuffer; +class CommandBufferDescriptor; +class CommandQueue; +class ComputeCommandEncoder; +class ComputePassDescriptor; +class Device; +class Drawable; +class Event; +class LogContainer; +class LogState; +class ParallelRenderCommandEncoder; +class RenderCommandEncoder; +class RenderPassDescriptor; +class ResidencySet; +class ResourceStateCommandEncoder; +class ResourceStatePassDescriptor; +_MTL_ENUM(NS::UInteger, CommandBufferStatus) { + CommandBufferStatusNotEnqueued = 0, + CommandBufferStatusEnqueued = 1, + CommandBufferStatusCommitted = 2, + CommandBufferStatusScheduled = 3, + CommandBufferStatusCompleted = 4, + CommandBufferStatusError = 5, +}; + +_MTL_ENUM(NS::UInteger, CommandBufferError) { + CommandBufferErrorNone = 0, + CommandBufferErrorInternal = 1, + CommandBufferErrorTimeout = 2, + CommandBufferErrorPageFault = 3, + CommandBufferErrorBlacklisted = 4, + CommandBufferErrorAccessRevoked = 4, + CommandBufferErrorNotPermitted = 7, + CommandBufferErrorOutOfMemory = 8, + CommandBufferErrorInvalidResource = 9, + CommandBufferErrorMemoryless = 10, + CommandBufferErrorDeviceRemoved = 11, + CommandBufferErrorStackOverflow = 12, +}; + +_MTL_ENUM(NS::Integer, CommandEncoderErrorState) { + CommandEncoderErrorStateUnknown = 0, + CommandEncoderErrorStateCompleted = 1, + CommandEncoderErrorStateAffected = 2, + CommandEncoderErrorStatePending = 3, + CommandEncoderErrorStateFaulted = 4, +}; + +_MTL_ENUM(NS::UInteger, DispatchType) { + DispatchTypeSerial = 0, + 
DispatchTypeConcurrent = 1, +}; + +_MTL_OPTIONS(NS::UInteger, CommandBufferErrorOption) { + CommandBufferErrorOptionNone = 0, + CommandBufferErrorOptionEncoderExecutionStatus = 1, +}; + +using CommandBufferHandler = void (^)(CommandBuffer*); +using HandlerFunction = std::function; + +class CommandBufferDescriptor : public NS::Copying +{ +public: + static CommandBufferDescriptor* alloc(); + + CommandBufferErrorOption errorOptions() const; + + CommandBufferDescriptor* init(); + + LogState* logState() const; + + bool retainedReferences() const; + + void setErrorOptions(MTL::CommandBufferErrorOption errorOptions); + + void setLogState(const MTL::LogState* logState); + + void setRetainedReferences(bool retainedReferences); +}; +class CommandBufferEncoderInfo : public NS::Referencing +{ +public: + NS::Array* debugSignposts() const; + + CommandEncoderErrorState errorState() const; + + NS::String* label() const; +}; +class CommandBuffer : public NS::Referencing +{ +public: + CFTimeInterval GPUEndTime() const; + + CFTimeInterval GPUStartTime() const; + + AccelerationStructureCommandEncoder* accelerationStructureCommandEncoder(); + AccelerationStructureCommandEncoder* accelerationStructureCommandEncoder(const MTL::AccelerationStructurePassDescriptor* descriptor); + + void addCompletedHandler(const MTL::CommandBufferHandler block); + void addCompletedHandler(const MTL::HandlerFunction& function); + + void addScheduledHandler(const MTL::CommandBufferHandler block); + void addScheduledHandler(const MTL::HandlerFunction& function); + + BlitCommandEncoder* blitCommandEncoder(); + BlitCommandEncoder* blitCommandEncoder(const MTL::BlitPassDescriptor* blitPassDescriptor); + + CommandQueue* commandQueue() const; + + void commit(); + + ComputeCommandEncoder* computeCommandEncoder(const MTL::ComputePassDescriptor* computePassDescriptor); + ComputeCommandEncoder* computeCommandEncoder(); + ComputeCommandEncoder* computeCommandEncoder(MTL::DispatchType dispatchType); + + Device* device() 
const; + + void encodeSignalEvent(const MTL::Event* event, uint64_t value); + + void encodeWait(const MTL::Event* event, uint64_t value); + + void enqueue(); + + NS::Error* error() const; + CommandBufferErrorOption errorOptions() const; + + CFTimeInterval kernelEndTime() const; + + CFTimeInterval kernelStartTime() const; + + NS::String* label() const; + + LogContainer* logs() const; + + ParallelRenderCommandEncoder* parallelRenderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor); + + void popDebugGroup(); + + void presentDrawable(const MTL::Drawable* drawable); + void presentDrawableAfterMinimumDuration(const MTL::Drawable* drawable, CFTimeInterval duration); + + void presentDrawableAtTime(const MTL::Drawable* drawable, CFTimeInterval presentationTime); + + void pushDebugGroup(const NS::String* string); + + RenderCommandEncoder* renderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor); + + ResourceStateCommandEncoder* resourceStateCommandEncoder(); + ResourceStateCommandEncoder* resourceStateCommandEncoder(const MTL::ResourceStatePassDescriptor* resourceStatePassDescriptor); + + bool retainedReferences() const; + + void setLabel(const NS::String* label); + + CommandBufferStatus status() const; + + void useResidencySet(const MTL::ResidencySet* residencySet); + void useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void waitUntilCompleted(); + + void waitUntilScheduled(); +}; + +} +_MTL_INLINE MTL::CommandBufferDescriptor* MTL::CommandBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCommandBufferDescriptor)); +} + +_MTL_INLINE MTL::CommandBufferErrorOption MTL::CommandBufferDescriptor::errorOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorOptions)); +} + +_MTL_INLINE MTL::CommandBufferDescriptor* MTL::CommandBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL::CommandBufferDescriptor::logState() 
const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE bool MTL::CommandBufferDescriptor::retainedReferences() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(retainedReferences)); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setErrorOptions(MTL::CommandBufferErrorOption errorOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setErrorOptions_), errorOptions); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL::CommandBufferDescriptor::setRetainedReferences(bool retainedReferences) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRetainedReferences_), retainedReferences); +} + +_MTL_INLINE NS::Array* MTL::CommandBufferEncoderInfo::debugSignposts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(debugSignposts)); +} + +_MTL_INLINE MTL::CommandEncoderErrorState MTL::CommandBufferEncoderInfo::errorState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorState)); +} + +_MTL_INLINE NS::String* MTL::CommandBufferEncoderInfo::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::GPUEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::GPUStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(GPUStartTime)); +} + +_MTL_INLINE MTL::AccelerationStructureCommandEncoder* MTL::CommandBuffer::accelerationStructureCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(accelerationStructureCommandEncoder)); +} + +_MTL_INLINE MTL::AccelerationStructureCommandEncoder* MTL::CommandBuffer::accelerationStructureCommandEncoder(const MTL::AccelerationStructurePassDescriptor* descriptor) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(accelerationStructureCommandEncoderWithDescriptor_), descriptor); +} + +_MTL_INLINE void MTL::CommandBuffer::addCompletedHandler(const MTL::CommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addCompletedHandler_), block); +} + +_MTL_INLINE void MTL::CommandBuffer::addCompletedHandler(const MTL::HandlerFunction& function) +{ + __block HandlerFunction blockFunction = function; + addCompletedHandler(^(MTL::CommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE void MTL::CommandBuffer::addScheduledHandler(const MTL::CommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addScheduledHandler_), block); +} + +_MTL_INLINE void MTL::CommandBuffer::addScheduledHandler(const MTL::HandlerFunction& function) +{ + __block HandlerFunction blockFunction = function; + addScheduledHandler(^(MTL::CommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE MTL::BlitCommandEncoder* MTL::CommandBuffer::blitCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blitCommandEncoder)); +} + +_MTL_INLINE MTL::BlitCommandEncoder* MTL::CommandBuffer::blitCommandEncoder(const MTL::BlitPassDescriptor* blitPassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(blitCommandEncoderWithDescriptor_), blitPassDescriptor); +} + +_MTL_INLINE MTL::CommandQueue* MTL::CommandBuffer::commandQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandQueue)); +} + +_MTL_INLINE void MTL::CommandBuffer::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder(const MTL::ComputePassDescriptor* computePassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoderWithDescriptor_), computePassDescriptor); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder() +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoder)); +} + +_MTL_INLINE MTL::ComputeCommandEncoder* MTL::CommandBuffer::computeCommandEncoder(MTL::DispatchType dispatchType) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeCommandEncoderWithDispatchType_), dispatchType); +} + +_MTL_INLINE MTL::Device* MTL::CommandBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandBuffer::encodeSignalEvent(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(encodeSignalEvent_value_), event, value); +} + +_MTL_INLINE void MTL::CommandBuffer::encodeWait(const MTL::Event* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(encodeWaitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL::CommandBuffer::enqueue() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueue)); +} + +_MTL_INLINE NS::Error* MTL::CommandBuffer::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} + +_MTL_INLINE MTL::CommandBufferErrorOption MTL::CommandBuffer::errorOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(errorOptions)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::kernelEndTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(kernelEndTime)); +} + +_MTL_INLINE CFTimeInterval MTL::CommandBuffer::kernelStartTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(kernelStartTime)); +} + +_MTL_INLINE NS::String* MTL::CommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LogContainer* MTL::CommandBuffer::logs() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logs)); +} + +_MTL_INLINE MTL::ParallelRenderCommandEncoder* MTL::CommandBuffer::parallelRenderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(parallelRenderCommandEncoderWithDescriptor_), renderPassDescriptor); +} + +_MTL_INLINE void MTL::CommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawable(const MTL::Drawable* drawable) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_), drawable); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawableAfterMinimumDuration(const MTL::Drawable* drawable, CFTimeInterval duration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_afterMinimumDuration_), drawable, duration); +} + +_MTL_INLINE void MTL::CommandBuffer::presentDrawableAtTime(const MTL::Drawable* drawable, CFTimeInterval presentationTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentDrawable_atTime_), drawable, presentationTime); +} + +_MTL_INLINE void MTL::CommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE MTL::RenderCommandEncoder* MTL::CommandBuffer::renderCommandEncoder(const MTL::RenderPassDescriptor* renderPassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoderWithDescriptor_), renderPassDescriptor); +} + +_MTL_INLINE MTL::ResourceStateCommandEncoder* MTL::CommandBuffer::resourceStateCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceStateCommandEncoder)); +} + +_MTL_INLINE MTL::ResourceStateCommandEncoder* MTL::CommandBuffer::resourceStateCommandEncoder(const MTL::ResourceStatePassDescriptor* resourceStatePassDescriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceStateCommandEncoderWithDescriptor_), resourceStatePassDescriptor); +} + +_MTL_INLINE bool MTL::CommandBuffer::retainedReferences() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(retainedReferences)); +} + +_MTL_INLINE void MTL::CommandBuffer::setLabel(const NS::String* label) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::CommandBufferStatus MTL::CommandBuffer::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL::CommandBuffer::useResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandBuffer::useResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL::CommandBuffer::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} + +_MTL_INLINE void MTL::CommandBuffer::waitUntilScheduled() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilScheduled)); +} diff --git a/dist/include/metal_cpp/Metal/MTLCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLCommandEncoder.hpp new file mode 100644 index 0000000..a230ff5 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCommandEncoder.hpp @@ -0,0 +1,117 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; + +_MTL_OPTIONS(NS::UInteger, ResourceUsage) { + ResourceUsageRead = 1, + ResourceUsageWrite = 1 << 1, + ResourceUsageSample = 1 << 2, +}; + +_MTL_OPTIONS(NS::UInteger, BarrierScope) { + BarrierScopeBuffers = 1, + BarrierScopeTextures = 1 << 1, + BarrierScopeRenderTargets = 1 << 2, +}; + +_MTL_OPTIONS(NS::UInteger, Stages) { + StageVertex = 1, + StageFragment = 1 << 1, + StageTile = 1 << 2, + StageObject = 1 << 3, + StageMesh = 1 << 4, + StageResourceState = 1 << 26, + StageDispatch = 1 << 27, + StageBlit = 1 << 28, + StageAccelerationStructure = 1 << 29, + StageMachineLearning = 1 << 30, + StageAll = 9223372036854775807, +}; + +class CommandEncoder : public NS::Referencing +{ +public: + void barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages); + + Device* device() const; + + void endEncoding(); + + void insertDebugSignpost(const NS::String* string); + + NS::String* label() const; + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE void MTL::CommandEncoder::barrierAfterQueueStages(MTL::Stages afterQueueStages, MTL::Stages beforeStages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(barrierAfterQueueStages_beforeStages_), afterQueueStages, beforeStages); +} + +_MTL_INLINE MTL::Device* MTL::CommandEncoder::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandEncoder::endEncoding() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endEncoding)); +} + +_MTL_INLINE void MTL::CommandEncoder::insertDebugSignpost(const NS::String* string) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugSignpost_), string); +} + +_MTL_INLINE NS::String* MTL::CommandEncoder::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::CommandEncoder::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::CommandEncoder::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL::CommandEncoder::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTLCommandQueue.hpp b/dist/include/metal_cpp/Metal/MTLCommandQueue.hpp new file mode 100644 index 0000000..5d3bf16 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCommandQueue.hpp @@ -0,0 +1,158 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CommandBuffer; +class CommandBufferDescriptor; +class CommandQueueDescriptor; +class Device; +class LogState; +class ResidencySet; + +class CommandQueue : public NS::Referencing +{ +public: + void addResidencySet(const MTL::ResidencySet* residencySet); + void addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + CommandBuffer* commandBuffer(); + CommandBuffer* commandBuffer(const MTL::CommandBufferDescriptor* descriptor); + CommandBuffer* commandBufferWithUnretainedReferences(); + + Device* device() const; + + void insertDebugCaptureBoundary(); + + NS::String* label() const; + + void removeResidencySet(const MTL::ResidencySet* residencySet); + void removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count); + + void setLabel(const NS::String* label); +}; +class CommandQueueDescriptor : public NS::Copying +{ +public: + static CommandQueueDescriptor* alloc(); + + CommandQueueDescriptor* init(); + + LogState* logState() const; + + NS::UInteger maxCommandBufferCount() const; + + void setLogState(const MTL::LogState* logState); + + void setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount); +}; + +} +_MTL_INLINE void MTL::CommandQueue::addResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandQueue::addResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addResidencySets_count_), residencySets, count); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBuffer() +{ + 
return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBuffer(const MTL::CommandBufferDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::CommandBuffer* MTL::CommandQueue::commandBufferWithUnretainedReferences() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithUnretainedReferences)); +} + +_MTL_INLINE MTL::Device* MTL::CommandQueue::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::CommandQueue::insertDebugCaptureBoundary() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(insertDebugCaptureBoundary)); +} + +_MTL_INLINE NS::String* MTL::CommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::CommandQueue::removeResidencySet(const MTL::ResidencySet* residencySet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySet_), residencySet); +} + +_MTL_INLINE void MTL::CommandQueue::removeResidencySets(const MTL::ResidencySet* const residencySets[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeResidencySets_count_), residencySets, count); +} + +_MTL_INLINE void MTL::CommandQueue::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCommandQueueDescriptor)); +} + +_MTL_INLINE MTL::CommandQueueDescriptor* MTL::CommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogState* MTL::CommandQueueDescriptor::logState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(logState)); +} + +_MTL_INLINE NS::UInteger MTL::CommandQueueDescriptor::maxCommandBufferCount() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(maxCommandBufferCount)); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setLogState(const MTL::LogState* logState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLogState_), logState); +} + +_MTL_INLINE void MTL::CommandQueueDescriptor::setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandBufferCount_), maxCommandBufferCount); +} diff --git a/dist/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp new file mode 100644 index 0000000..2f555e5 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp @@ -0,0 +1,324 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputeCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class ComputePipelineState; +class CounterSampleBuffer; +class Fence; +class Heap; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class Resource; +class SamplerState; +class Texture; +class VisibleFunctionTable; + +struct DispatchThreadgroupsIndirectArguments +{ + uint32_t threadgroupsPerGrid[3]; +} _MTL_PACKED; + +struct DispatchThreadsIndirectArguments +{ + uint32_t threadsPerGrid[3]; + uint32_t threadsPerThreadgroup[3]; +} _MTL_PACKED; + +struct StageInRegionIndirectArguments +{ + uint32_t stageInOrigin[3]; + uint32_t stageInSize[3]; +} _MTL_PACKED; + +class ComputeCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + void dispatchThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerThreadgroup); + + void dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + + DispatchType dispatchType() const; + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset); + + void memoryBarrier(MTL::BarrierScope scope); + void memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool 
barrier); + + void setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + void setBufferOffset(NS::UInteger offset, NS::UInteger index); + void setBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range); + + void setBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + void setBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index); + + void setComputePipelineState(const MTL::ComputePipelineState* state); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setStageInRegion(MTL::Region region); + void setStageInRegion(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void setTexture(const MTL::Texture* texture, NS::UInteger index); + void setTextures(const MTL::Texture* const textures[], NS::Range range); + 
+ void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger bufferIndex); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range); + + void updateFence(const MTL::Fence* fence); + + void useHeap(const MTL::Heap* heap); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + + void waitForFence(const MTL::Fence* fence); +}; + +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerThreadgroup_), indirectBuffer, indirectBufferOffset, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::dispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE MTL::DispatchType MTL::ComputeCommandEncoder::dispatchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchType)); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_), indirectCommandbuffer, indirectRangeBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::memoryBarrier(MTL::BarrierScope scope) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithScope_), scope); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithResources_count_), resources, count); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void 
MTL::ComputeCommandEncoder::setBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferOffset_attributeStride_atIndex_), offset, stride, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_attributeStrides_withRange_), buffers, offsets, strides, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBytes_length_attributeStride_atIndex_), bytes, length, stride, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setComputePipelineState(const MTL::ComputePipelineState* state) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), state); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setIntersectionFunctionTable(const MTL::IntersectionFunctionTable* 
intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setStageInRegion(MTL::Region region) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegion_), region); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setStageInRegion(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegionWithIndirectBuffer_indirectBufferOffset_), indirectBuffer, 
indirectBufferOffset); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setVisibleFunctionTable(const MTL::VisibleFunctionTable* visibleFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atBufferIndex_), visibleFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const visibleFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withBufferRange_), visibleFunctionTables, range); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, usage); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::useResources(const 
MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::ComputeCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/dist/include/metal_cpp/Metal/MTLComputePass.hpp b/dist/include/metal_cpp/Metal/MTLComputePass.hpp new file mode 100644 index 0000000..fb34f7d --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLComputePass.hpp @@ -0,0 +1,169 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class ComputePassDescriptor; +class ComputePassSampleBufferAttachmentDescriptor; +class ComputePassSampleBufferAttachmentDescriptorArray; +class CounterSampleBuffer; + +class ComputePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static ComputePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + ComputePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class ComputePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static ComputePassSampleBufferAttachmentDescriptorArray* alloc(); + + ComputePassSampleBufferAttachmentDescriptorArray* init(); + + ComputePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::ComputePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class ComputePassDescriptor : public NS::Copying +{ +public: + static ComputePassDescriptor* alloc(); + + static ComputePassDescriptor* computePassDescriptor(); + + DispatchType dispatchType() const; + + ComputePassDescriptor* init(); + + ComputePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; + + void setDispatchType(MTL::DispatchType dispatchType); +}; + +} +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* 
MTL::ComputePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* MTL::ComputePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::ComputePassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::ComputePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassSampleBufferAttachmentDescriptorArray::init() +{ + return 
NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptor* MTL::ComputePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::ComputePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::ComputePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePassDescriptor)); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::computePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLComputePassDescriptor), _MTL_PRIVATE_SEL(computePassDescriptor)); +} + +_MTL_INLINE MTL::DispatchType MTL::ComputePassDescriptor::dispatchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchType)); +} + +_MTL_INLINE MTL::ComputePassDescriptor* MTL::ComputePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePassSampleBufferAttachmentDescriptorArray* MTL::ComputePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} + +_MTL_INLINE void MTL::ComputePassDescriptor::setDispatchType(MTL::DispatchType dispatchType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDispatchType_), dispatchType); +} diff --git a/dist/include/metal_cpp/Metal/MTLComputePipeline.hpp b/dist/include/metal_cpp/Metal/MTLComputePipeline.hpp new file mode 100644 index 0000000..d200af7 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLComputePipeline.hpp @@ -0,0 +1,439 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLComputePipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class ComputePipelineDescriptor; +class ComputePipelineReflection; +class ComputePipelineState; +class Device; +class Function; +class FunctionHandle; +class IntersectionFunctionTable; +class IntersectionFunctionTableDescriptor; +class LinkedFunctions; +class PipelineBufferDescriptorArray; +class StageInputOutputDescriptor; +class VisibleFunctionTable; +class VisibleFunctionTableDescriptor; + +} +namespace MTL4 +{ +class BinaryFunction; + +} +namespace MTL +{ +class ComputePipelineReflection : public NS::Referencing +{ +public: + static ComputePipelineReflection* alloc(); + + NS::Array* arguments() const; + + NS::Array* bindings() const; + + ComputePipelineReflection* init(); +}; +class ComputePipelineDescriptor : public NS::Copying +{ +public: + static 
ComputePipelineDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + PipelineBufferDescriptorArray* buffers() const; + + Function* computeFunction() const; + + ComputePipelineDescriptor* init(); + + NS::Array* insertLibraries() const; + + NS::String* label() const; + + LinkedFunctions* linkedFunctions() const; + + NS::UInteger maxCallStackDepth() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::Array* preloadedLibraries() const; + + Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setComputeFunction(const MTL::Function* computeFunction); + + void setInsertLibraries(const NS::Array* insertLibraries); + + void setLabel(const NS::String* label); + + void setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor); + + void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth); + + ShaderValidation shaderValidation() const; + + StageInputOutputDescriptor* stageInputDescriptor() const; + + bool supportAddingBinaryFunctions() const; + + bool supportIndirectCommandBuffers() const; + + bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const; +}; +class ComputePipelineState : public NS::Referencing +{ +public: + Device* device() const; + + FunctionHandle* functionHandle(const NS::String* 
name); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function); + FunctionHandle* functionHandle(const MTL::Function* function); + + ResourceID gpuResourceID() const; + + NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions); + + NS::String* label() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + ComputePipelineState* newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error); + ComputePipelineState* newComputePipelineState(const NS::Array* functions, NS::Error** error); + + IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor); + + VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor); + + ComputePipelineReflection* reflection() const; + + Size requiredThreadsPerThreadgroup() const; + + ShaderValidation shaderValidation() const; + + NS::UInteger staticThreadgroupMemoryLength() const; + + bool supportIndirectCommandBuffers() const; + + NS::UInteger threadExecutionWidth() const; +}; + +} +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePipelineReflection)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::arguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arguments)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLComputePipelineDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::ComputePipelineDescriptor::buffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffers)); +} + +_MTL_INLINE MTL::Function* MTL::ComputePipelineDescriptor::computeFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(computeFunction)); +} + +_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::insertLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(insertLibraries)); +} + +_MTL_INLINE NS::String* MTL::ComputePipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::ComputePipelineDescriptor::linkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE MTL::Size MTL::ComputePipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void 
MTL::ComputePipelineDescriptor::setComputeFunction(const MTL::Function* computeFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputeFunction_), computeFunction); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setInsertLibraries(const NS::Array* insertLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInsertLibraries_), insertLibraries); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLinkedFunctions_), linkedFunctions); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInputDescriptor_), stageInputDescriptor); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingBinaryFunctions_), supportAddingBinaryFunctions); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL::ComputePipelineDescriptor::setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_), threadGroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::ComputePipelineDescriptor::stageInputDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stageInputDescriptor)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportAddingBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingBinaryFunctions)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL::ComputePipelineDescriptor::threadGroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL::Device* MTL::ComputePipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE 
MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const NS::String* name) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithName_), name); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL4::BinaryFunction* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_), function); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL::Function* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_), function); +} + +_MTL_INLINE MTL::ResourceID MTL::ComputePipelineState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::imageblockMemoryLength(MTL::Size imageblockDimensions) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockMemoryLengthForDimensions_), imageblockDimensions); +} + +_MTL_INLINE NS::String* MTL::ComputePipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithBinaryFunctions_error_), additionalBinaryFunctions, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineState(const NS::Array* functions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_), functions, error); +} + +_MTL_INLINE MTL::IntersectionFunctionTable* 
MTL::ComputePipelineState::newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::VisibleFunctionTable* MTL::ComputePipelineState::newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newVisibleFunctionTableWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} + +_MTL_INLINE MTL::Size MTL::ComputePipelineState::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::staticThreadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(staticThreadgroupMemoryLength)); +} + +_MTL_INLINE bool MTL::ComputePipelineState::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE NS::UInteger MTL::ComputePipelineState::threadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadExecutionWidth)); +} diff --git a/dist/include/metal_cpp/Metal/MTLCounters.hpp b/dist/include/metal_cpp/Metal/MTLCounters.hpp new file mode 100644 index 0000000..6d655f1 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLCounters.hpp @@ -0,0 +1,243 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLCounters.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include + +namespace MTL +{ +class CounterSampleBufferDescriptor; +class CounterSet; +class Device; +_MTL_ENUM(NS::Integer, CounterSampleBufferError) { + CounterSampleBufferErrorOutOfMemory = 0, + CounterSampleBufferErrorInvalid = 1, + CounterSampleBufferErrorInternal = 2, +}; + +using CommonCounter = NS::String*; +using CommonCounterSet = NS::String*; + +static const NS::UInteger CounterErrorValue = static_cast(~0ULL); +static const NS::UInteger CounterDontSample = static_cast(-1); +_MTL_CONST(NS::ErrorDomain, CounterErrorDomain); +_MTL_CONST(CommonCounter, CommonCounterTimestamp); +_MTL_CONST(CommonCounter, CommonCounterTessellationInputPatches); +_MTL_CONST(CommonCounter, CommonCounterVertexInvocations); +_MTL_CONST(CommonCounter, CommonCounterPostTessellationVertexInvocations); +_MTL_CONST(CommonCounter, CommonCounterClipperInvocations); +_MTL_CONST(CommonCounter, CommonCounterClipperPrimitivesOut); +_MTL_CONST(CommonCounter, CommonCounterFragmentInvocations); +_MTL_CONST(CommonCounter, CommonCounterFragmentsPassed); +_MTL_CONST(CommonCounter, 
CommonCounterComputeKernelInvocations); +_MTL_CONST(CommonCounter, CommonCounterTotalCycles); +_MTL_CONST(CommonCounter, CommonCounterVertexCycles); +_MTL_CONST(CommonCounter, CommonCounterTessellationCycles); +_MTL_CONST(CommonCounter, CommonCounterPostTessellationVertexCycles); +_MTL_CONST(CommonCounter, CommonCounterFragmentCycles); +_MTL_CONST(CommonCounter, CommonCounterRenderTargetWriteCycles); +_MTL_CONST(CommonCounterSet, CommonCounterSetTimestamp); +_MTL_CONST(CommonCounterSet, CommonCounterSetStageUtilization); +_MTL_CONST(CommonCounterSet, CommonCounterSetStatistic); +struct CounterResultTimestamp +{ + uint64_t timestamp; +} _MTL_PACKED; + +struct CounterResultStageUtilization +{ + uint64_t totalCycles; + uint64_t vertexCycles; + uint64_t tessellationCycles; + uint64_t postTessellationVertexCycles; + uint64_t fragmentCycles; + uint64_t renderTargetCycles; +} _MTL_PACKED; + +struct CounterResultStatistic +{ + uint64_t tessellationInputPatches; + uint64_t vertexInvocations; + uint64_t postTessellationVertexInvocations; + uint64_t clipperInvocations; + uint64_t clipperPrimitivesOut; + uint64_t fragmentInvocations; + uint64_t fragmentsPassed; + uint64_t computeKernelInvocations; +} _MTL_PACKED; + +class Counter : public NS::Referencing +{ +public: + NS::String* name() const; +}; +class CounterSet : public NS::Referencing +{ +public: + NS::Array* counters() const; + + NS::String* name() const; +}; +class CounterSampleBufferDescriptor : public NS::Copying +{ +public: + static CounterSampleBufferDescriptor* alloc(); + + CounterSet* counterSet() const; + + CounterSampleBufferDescriptor* init(); + + NS::String* label() const; + + NS::UInteger sampleCount() const; + + void setCounterSet(const MTL::CounterSet* counterSet); + + void setLabel(const NS::String* label); + + void setSampleCount(NS::UInteger sampleCount); + + void setStorageMode(MTL::StorageMode storageMode); + StorageMode storageMode() const; +}; +class CounterSampleBuffer : public NS::Referencing +{ 
+public: + Device* device() const; + + NS::String* label() const; + + NS::Data* resolveCounterRange(NS::Range range); + + NS::UInteger sampleCount() const; +}; + +} + +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, CounterErrorDomain); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTimestamp); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTessellationInputPatches); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterVertexInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterPostTessellationVertexInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterClipperInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterClipperPrimitivesOut); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentsPassed); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterComputeKernelInvocations); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTotalCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterVertexCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterTessellationCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterPostTessellationVertexCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterFragmentCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounter, CommonCounterRenderTargetWriteCycles); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetTimestamp); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetStageUtilization); +_MTL_PRIVATE_DEF_CONST(MTL::CommonCounterSet, CommonCounterSetStatistic); + +_MTL_INLINE NS::String* MTL::Counter::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE NS::Array* MTL::CounterSet::counters() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counters)); +} + +_MTL_INLINE NS::String* MTL::CounterSet::name() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::CounterSampleBufferDescriptor* MTL::CounterSampleBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCounterSampleBufferDescriptor)); +} + +_MTL_INLINE MTL::CounterSet* MTL::CounterSampleBufferDescriptor::counterSet() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counterSet)); +} + +_MTL_INLINE MTL::CounterSampleBufferDescriptor* MTL::CounterSampleBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::CounterSampleBufferDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::CounterSampleBufferDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setCounterSet(const MTL::CounterSet* counterSet) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCounterSet_), counterSet); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::CounterSampleBufferDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE MTL::StorageMode MTL::CounterSampleBufferDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::Device* MTL::CounterSampleBuffer::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::CounterSampleBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::Data* 
MTL::CounterSampleBuffer::resolveCounterRange(NS::Range range) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveCounterRange_), range); +} + +_MTL_INLINE NS::UInteger MTL::CounterSampleBuffer::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} diff --git a/dist/include/metal_cpp/Metal/MTLDataType.hpp b/dist/include/metal_cpp/Metal/MTLDataType.hpp new file mode 100644 index 0000000..f0e9b25 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDataType.hpp @@ -0,0 +1,129 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDataType.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, DataType) { + DataTypeNone = 0, + DataTypeStruct = 1, + DataTypeArray = 2, + DataTypeFloat = 3, + DataTypeFloat2 = 4, + DataTypeFloat3 = 5, + DataTypeFloat4 = 6, + DataTypeFloat2x2 = 7, + DataTypeFloat2x3 = 8, + DataTypeFloat2x4 = 9, + DataTypeFloat3x2 = 10, + DataTypeFloat3x3 = 11, + DataTypeFloat3x4 = 12, + DataTypeFloat4x2 = 13, + DataTypeFloat4x3 = 14, + DataTypeFloat4x4 = 15, + DataTypeHalf = 16, + DataTypeHalf2 = 17, + DataTypeHalf3 = 18, + DataTypeHalf4 = 19, + DataTypeHalf2x2 = 20, + DataTypeHalf2x3 = 21, + DataTypeHalf2x4 = 22, + DataTypeHalf3x2 = 23, + DataTypeHalf3x3 = 24, + DataTypeHalf3x4 = 25, + DataTypeHalf4x2 = 26, + DataTypeHalf4x3 = 27, + DataTypeHalf4x4 = 28, + DataTypeInt = 29, + DataTypeInt2 = 30, + DataTypeInt3 = 31, + DataTypeInt4 = 32, + DataTypeUInt = 33, + DataTypeUInt2 = 34, + DataTypeUInt3 = 35, + DataTypeUInt4 = 36, + DataTypeShort = 37, + DataTypeShort2 = 38, + DataTypeShort3 = 39, + DataTypeShort4 = 40, + DataTypeUShort = 41, + DataTypeUShort2 = 42, + DataTypeUShort3 = 43, + DataTypeUShort4 = 44, + DataTypeChar = 45, + DataTypeChar2 = 46, + DataTypeChar3 = 47, + DataTypeChar4 = 48, + DataTypeUChar = 49, + DataTypeUChar2 = 50, + DataTypeUChar3 = 51, + DataTypeUChar4 = 52, + DataTypeBool = 53, + DataTypeBool2 = 54, + DataTypeBool3 = 55, + DataTypeBool4 = 56, + DataTypeTexture = 58, + DataTypeSampler = 59, + DataTypePointer = 60, + DataTypeR8Unorm = 62, + DataTypeR8Snorm = 63, + DataTypeR16Unorm = 64, + DataTypeR16Snorm = 65, + DataTypeRG8Unorm = 66, + DataTypeRG8Snorm = 67, + DataTypeRG16Unorm = 68, + DataTypeRG16Snorm = 69, + DataTypeRGBA8Unorm = 70, + DataTypeRGBA8Unorm_sRGB = 71, + DataTypeRGBA8Snorm = 72, + 
DataTypeRGBA16Unorm = 73, + DataTypeRGBA16Snorm = 74, + DataTypeRGB10A2Unorm = 75, + DataTypeRG11B10Float = 76, + DataTypeRGB9E5Float = 77, + DataTypeRenderPipeline = 78, + DataTypeComputePipeline = 79, + DataTypeIndirectCommandBuffer = 80, + DataTypeLong = 81, + DataTypeLong2 = 82, + DataTypeLong3 = 83, + DataTypeLong4 = 84, + DataTypeULong = 85, + DataTypeULong2 = 86, + DataTypeULong3 = 87, + DataTypeULong4 = 88, + DataTypeVisibleFunctionTable = 115, + DataTypeIntersectionFunctionTable = 116, + DataTypePrimitiveAccelerationStructure = 117, + DataTypeInstanceAccelerationStructure = 118, + DataTypeBFloat = 121, + DataTypeBFloat2 = 122, + DataTypeBFloat3 = 123, + DataTypeBFloat4 = 124, + DataTypeDepthStencilState = 139, + DataTypeTensor = 140, +}; + +} diff --git a/dist/include/metal_cpp/Metal/MTLDefines.hpp b/dist/include/metal_cpp/Metal/MTLDefines.hpp new file mode 100644 index 0000000..4260a2b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTL_EXPORT _NS_EXPORT +#define _MTL_EXTERN _NS_EXTERN +#define _MTL_INLINE _NS_INLINE +#define _MTL_PACKED _NS_PACKED + +#define _MTL_CONST(type, name) _NS_CONST(type, name) +#define _MTL_ENUM(type, name) _NS_ENUM(type, name) +#define _MTL_OPTIONS(type, name) _NS_OPTIONS(type, name) + +#define _MTL_VALIDATE_SIZE(ns, name) _NS_VALIDATE_SIZE(ns, name) +#define _MTL_VALIDATE_ENUM(ns, name) _NS_VALIDATE_ENUM(ns, name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Metal/MTLDepthStencil.hpp b/dist/include/metal_cpp/Metal/MTLDepthStencil.hpp new file mode 100644 index 0000000..e111617 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDepthStencil.hpp @@ -0,0 +1,277 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDepthStencil.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class DepthStencilDescriptor; +class Device; +class StencilDescriptor; +_MTL_ENUM(NS::UInteger, CompareFunction) { + CompareFunctionNever = 0, + CompareFunctionLess = 1, + CompareFunctionEqual = 2, + CompareFunctionLessEqual = 3, + CompareFunctionGreater = 4, + CompareFunctionNotEqual = 5, + CompareFunctionGreaterEqual = 6, + CompareFunctionAlways = 7, +}; + +_MTL_ENUM(NS::UInteger, StencilOperation) { + StencilOperationKeep = 0, + StencilOperationZero = 1, + StencilOperationReplace = 2, + StencilOperationIncrementClamp = 3, + StencilOperationDecrementClamp = 4, + StencilOperationInvert = 5, + StencilOperationIncrementWrap = 6, + StencilOperationDecrementWrap = 7, +}; + +class StencilDescriptor : public NS::Copying +{ +public: + static StencilDescriptor* alloc(); + + StencilOperation depthFailureOperation() const; + + StencilOperation depthStencilPassOperation() const; + + StencilDescriptor* init(); + + uint32_t readMask() const; + + void setDepthFailureOperation(MTL::StencilOperation depthFailureOperation); + + void setDepthStencilPassOperation(MTL::StencilOperation depthStencilPassOperation); + + void setReadMask(uint32_t readMask); + + void setStencilCompareFunction(MTL::CompareFunction 
stencilCompareFunction); + + void setStencilFailureOperation(MTL::StencilOperation stencilFailureOperation); + + void setWriteMask(uint32_t writeMask); + + CompareFunction stencilCompareFunction() const; + + StencilOperation stencilFailureOperation() const; + + uint32_t writeMask() const; +}; +class DepthStencilDescriptor : public NS::Copying +{ +public: + static DepthStencilDescriptor* alloc(); + + StencilDescriptor* backFaceStencil() const; + + CompareFunction depthCompareFunction() const; + + [[deprecated("please use isDepthWriteEnabled instead")]] + bool depthWriteEnabled() const; + + StencilDescriptor* frontFaceStencil() const; + + DepthStencilDescriptor* init(); + + bool isDepthWriteEnabled() const; + + NS::String* label() const; + + void setBackFaceStencil(const MTL::StencilDescriptor* backFaceStencil); + + void setDepthCompareFunction(MTL::CompareFunction depthCompareFunction); + + void setDepthWriteEnabled(bool depthWriteEnabled); + + void setFrontFaceStencil(const MTL::StencilDescriptor* frontFaceStencil); + + void setLabel(const NS::String* label); +}; +class DepthStencilState : public NS::Referencing +{ +public: + Device* device() const; + + ResourceID gpuResourceID() const; + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::StencilDescriptor* MTL::StencilDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStencilDescriptor)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::depthFailureOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthFailureOperation)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::depthStencilPassOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthStencilPassOperation)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::StencilDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE uint32_t MTL::StencilDescriptor::readMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(readMask)); +} + +_MTL_INLINE 
void MTL::StencilDescriptor::setDepthFailureOperation(MTL::StencilOperation depthFailureOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthFailureOperation_), depthFailureOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setDepthStencilPassOperation(MTL::StencilOperation depthStencilPassOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilPassOperation_), depthStencilPassOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setReadMask(uint32_t readMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setReadMask_), readMask); +} + +_MTL_INLINE void MTL::StencilDescriptor::setStencilCompareFunction(MTL::CompareFunction stencilCompareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilCompareFunction_), stencilCompareFunction); +} + +_MTL_INLINE void MTL::StencilDescriptor::setStencilFailureOperation(MTL::StencilOperation stencilFailureOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFailureOperation_), stencilFailureOperation); +} + +_MTL_INLINE void MTL::StencilDescriptor::setWriteMask(uint32_t writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::CompareFunction MTL::StencilDescriptor::stencilCompareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilCompareFunction)); +} + +_MTL_INLINE MTL::StencilOperation MTL::StencilDescriptor::stencilFailureOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilFailureOperation)); +} + +_MTL_INLINE uint32_t MTL::StencilDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL::DepthStencilDescriptor* MTL::DepthStencilDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLDepthStencilDescriptor)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::DepthStencilDescriptor::backFaceStencil() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(backFaceStencil)); +} + +_MTL_INLINE MTL::CompareFunction MTL::DepthStencilDescriptor::depthCompareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthCompareFunction)); +} + +_MTL_INLINE bool MTL::DepthStencilDescriptor::depthWriteEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthWriteEnabled)); +} + +_MTL_INLINE MTL::StencilDescriptor* MTL::DepthStencilDescriptor::frontFaceStencil() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(frontFaceStencil)); +} + +_MTL_INLINE MTL::DepthStencilDescriptor* MTL::DepthStencilDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::DepthStencilDescriptor::isDepthWriteEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isDepthWriteEnabled)); +} + +_MTL_INLINE NS::String* MTL::DepthStencilDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setBackFaceStencil(const MTL::StencilDescriptor* backFaceStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBackFaceStencil_), backFaceStencil); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setDepthCompareFunction(MTL::CompareFunction depthCompareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthCompareFunction_), depthCompareFunction); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setDepthWriteEnabled(bool depthWriteEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthWriteEnabled_), depthWriteEnabled); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setFrontFaceStencil(const MTL::StencilDescriptor* frontFaceStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFaceStencil_), frontFaceStencil); +} + +_MTL_INLINE void MTL::DepthStencilDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::Device* MTL::DepthStencilState::device() const +{ + 
return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::ResourceID MTL::DepthStencilState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::DepthStencilState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} diff --git a/dist/include/metal_cpp/Metal/MTLDevice.hpp b/dist/include/metal_cpp/Metal/MTLDevice.hpp new file mode 100644 index 0000000..0e86739 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDevice.hpp @@ -0,0 +1,1493 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDevice.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTL4Counters.hpp" +#include "MTLArgument.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTexture.hpp" +#include "MTLTypes.hpp" +#include +#include +#include + +#include +#include +#include + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class Architecture; +class ArgumentDescriptor; +class ArgumentEncoder; +class BinaryArchive; +class BinaryArchiveDescriptor; +class Buffer; +class BufferBinding; +class CommandQueue; +class CommandQueueDescriptor; +class CompileOptions; +class ComputePipelineDescriptor; +class ComputePipelineReflection; +class ComputePipelineState; +class CounterSampleBuffer; +class CounterSampleBufferDescriptor; +class DepthStencilDescriptor; +class DepthStencilState; +class Device; +class DynamicLibrary; +class Event; +class Fence; +class Function; +class FunctionConstantValues; +class FunctionHandle; +class Heap; +class HeapDescriptor; +class IOCommandQueue; +class IOCommandQueueDescriptor; +class IOFileHandle; +class IndirectCommandBuffer; +class IndirectCommandBufferDescriptor; +class Library; +class LogState; +class LogStateDescriptor; +class MeshRenderPipelineDescriptor; +class RasterizationRateMap; +class RasterizationRateMapDescriptor; +struct Region; +class RenderPipelineDescriptor; +class RenderPipelineReflection; +class RenderPipelineState; +class ResidencySet; +class ResidencySetDescriptor; +class ResourceViewPoolDescriptor; +struct SamplePosition; +class SamplerDescriptor; +class SamplerState; +class SharedEvent; +class SharedEventHandle; +class SharedTextureHandle; +class StitchedLibraryDescriptor; +class Tensor; +class 
TensorDescriptor; +class Texture; +class TextureDescriptor; +class TextureViewPool; +class TileRenderPipelineDescriptor; + +} +namespace MTL4 +{ +class Archive; +class ArgumentTable; +class ArgumentTableDescriptor; +class BinaryFunction; +class CommandAllocator; +class CommandAllocatorDescriptor; +class CommandBuffer; +class CommandQueue; +class CommandQueueDescriptor; +class Compiler; +class CompilerDescriptor; +class CounterHeap; +class CounterHeapDescriptor; +class PipelineDataSetSerializer; +class PipelineDataSetSerializerDescriptor; + +} +namespace MTL +{ +_MTL_ENUM(NS::Integer, IOCompressionMethod) { + IOCompressionMethodZlib = 0, + IOCompressionMethodLZFSE = 1, + IOCompressionMethodLZ4 = 2, + IOCompressionMethodLZMA = 3, + IOCompressionMethodLZBitmap = 4, +}; + +_MTL_ENUM(NS::UInteger, FeatureSet) { + FeatureSet_iOS_GPUFamily1_v1 = 0, + FeatureSet_iOS_GPUFamily2_v1 = 1, + FeatureSet_iOS_GPUFamily1_v2 = 2, + FeatureSet_iOS_GPUFamily2_v2 = 3, + FeatureSet_iOS_GPUFamily3_v1 = 4, + FeatureSet_iOS_GPUFamily1_v3 = 5, + FeatureSet_iOS_GPUFamily2_v3 = 6, + FeatureSet_iOS_GPUFamily3_v2 = 7, + FeatureSet_iOS_GPUFamily1_v4 = 8, + FeatureSet_iOS_GPUFamily2_v4 = 9, + FeatureSet_iOS_GPUFamily3_v3 = 10, + FeatureSet_iOS_GPUFamily4_v1 = 11, + FeatureSet_iOS_GPUFamily1_v5 = 12, + FeatureSet_iOS_GPUFamily2_v5 = 13, + FeatureSet_iOS_GPUFamily3_v4 = 14, + FeatureSet_iOS_GPUFamily4_v2 = 15, + FeatureSet_iOS_GPUFamily5_v1 = 16, + FeatureSet_macOS_GPUFamily1_v1 = 10000, + FeatureSet_OSX_GPUFamily1_v1 = 10000, + FeatureSet_macOS_GPUFamily1_v2 = 10001, + FeatureSet_OSX_GPUFamily1_v2 = 10001, + FeatureSet_macOS_ReadWriteTextureTier2 = 10002, + FeatureSet_OSX_ReadWriteTextureTier2 = 10002, + FeatureSet_macOS_GPUFamily1_v3 = 10003, + FeatureSet_macOS_GPUFamily1_v4 = 10004, + FeatureSet_macOS_GPUFamily2_v1 = 10005, + FeatureSet_watchOS_GPUFamily1_v1 = 20000, + FeatureSet_WatchOS_GPUFamily1_v1 = 20000, + FeatureSet_watchOS_GPUFamily2_v1 = 20001, + FeatureSet_WatchOS_GPUFamily2_v1 = 
20001, + FeatureSet_tvOS_GPUFamily1_v1 = 30000, + FeatureSet_TVOS_GPUFamily1_v1 = 30000, + FeatureSet_tvOS_GPUFamily1_v2 = 30001, + FeatureSet_tvOS_GPUFamily1_v3 = 30002, + FeatureSet_tvOS_GPUFamily2_v1 = 30003, + FeatureSet_tvOS_GPUFamily1_v4 = 30004, + FeatureSet_tvOS_GPUFamily2_v2 = 30005, +}; + +_MTL_ENUM(NS::Integer, GPUFamily) { + GPUFamilyApple1 = 1001, + GPUFamilyApple2 = 1002, + GPUFamilyApple3 = 1003, + GPUFamilyApple4 = 1004, + GPUFamilyApple5 = 1005, + GPUFamilyApple6 = 1006, + GPUFamilyApple7 = 1007, + GPUFamilyApple8 = 1008, + GPUFamilyApple9 = 1009, + GPUFamilyApple10 = 1010, + GPUFamilyMac1 = 2001, + GPUFamilyMac2 = 2002, + GPUFamilyCommon1 = 3001, + GPUFamilyCommon2 = 3002, + GPUFamilyCommon3 = 3003, + GPUFamilyMacCatalyst1 = 4001, + GPUFamilyMacCatalyst2 = 4002, + GPUFamilyMetal3 = 5001, + GPUFamilyMetal4 = 5002, +}; + +_MTL_ENUM(NS::UInteger, DeviceLocation) { + DeviceLocationBuiltIn = 0, + DeviceLocationSlot = 1, + DeviceLocationExternal = 2, + DeviceLocationUnspecified = NS::UIntegerMax, +}; + +_MTL_ENUM(NS::UInteger, ReadWriteTextureTier) { + ReadWriteTextureTierNone = 0, + ReadWriteTextureTier1 = 1, + ReadWriteTextureTier2 = 2, +}; + +_MTL_ENUM(NS::UInteger, ArgumentBuffersTier) { + ArgumentBuffersTier1 = 0, + ArgumentBuffersTier2 = 1, +}; + +_MTL_ENUM(NS::UInteger, SparseTextureRegionAlignmentMode) { + SparseTextureRegionAlignmentModeOutward = 0, + SparseTextureRegionAlignmentModeInward = 1, +}; + +_MTL_ENUM(NS::UInteger, CounterSamplingPoint) { + CounterSamplingPointAtStageBoundary = 0, + CounterSamplingPointAtDrawBoundary = 1, + CounterSamplingPointAtDispatchBoundary = 2, + CounterSamplingPointAtTileDispatchBoundary = 3, + CounterSamplingPointAtBlitBoundary = 4, +}; + +_MTL_OPTIONS(NS::UInteger, PipelineOption) { + PipelineOptionNone = 0, + PipelineOptionArgumentInfo = 1, + PipelineOptionBindingInfo = 1, + PipelineOptionBufferTypeInfo = 1 << 1, + PipelineOptionFailOnBinaryArchiveMiss = 1 << 2, +}; + +using DeviceNotificationName = 
NS::String*; +using DeviceNotificationHandlerBlock = void (^)(MTL::Device* pDevice, MTL::DeviceNotificationName notifyName); +using DeviceNotificationHandlerFunction = std::function; +using AutoreleasedComputePipelineReflection = MTL::ComputePipelineReflection*; +using AutoreleasedRenderPipelineReflection = MTL::RenderPipelineReflection*; +using NewLibraryCompletionHandler = void (^)(MTL::Library*, NS::Error*); +using NewLibraryCompletionHandlerFunction = std::function; +using NewRenderPipelineStateCompletionHandler = void (^)(MTL::RenderPipelineState*, NS::Error*); +using NewRenderPipelineStateCompletionHandlerFunction = std::function; +using NewRenderPipelineStateWithReflectionCompletionHandler = void (^)(MTL::RenderPipelineState*, MTL::RenderPipelineReflection*, NS::Error*); +using NewRenderPipelineStateWithReflectionCompletionHandlerFunction = std::function; +using NewComputePipelineStateCompletionHandler = void (^)(MTL::ComputePipelineState*, NS::Error*); +using NewComputePipelineStateCompletionHandlerFunction = std::function; +using NewComputePipelineStateWithReflectionCompletionHandler = void (^)(MTL::ComputePipelineState*, MTL::ComputePipelineReflection*, NS::Error*); +using NewComputePipelineStateWithReflectionCompletionHandlerFunction = std::function; +using Timestamp = std::uint64_t; + +_MTL_CONST(DeviceNotificationName, DeviceWasAddedNotification); +_MTL_CONST(DeviceNotificationName, DeviceRemovalRequestedNotification); +_MTL_CONST(DeviceNotificationName, DeviceWasRemovedNotification); +_MTL_CONST(NS::ErrorUserInfoKey, CommandBufferEncoderInfoErrorKey); +Device* CreateSystemDefaultDevice(); +NS::Array* CopyAllDevices(); +NS::Array* CopyAllDevicesWithObserver(NS::Object** pOutObserver, MTL::DeviceNotificationHandlerBlock handler); +NS::Array* CopyAllDevicesWithObserver(NS::Object** pOutObserver, const MTL::DeviceNotificationHandlerFunction& handler); +void RemoveDeviceObserver(const NS::Object* pObserver); +struct AccelerationStructureSizes +{ + 
NS::UInteger accelerationStructureSize; + NS::UInteger buildScratchBufferSize; + NS::UInteger refitScratchBufferSize; +} _MTL_PACKED; + +struct SizeAndAlign +{ + NS::UInteger size; + NS::UInteger align; +} _MTL_PACKED; + +class ArgumentDescriptor : public NS::Copying +{ +public: + BindingAccess access() const; + + static ArgumentDescriptor* alloc(); + + static ArgumentDescriptor* argumentDescriptor(); + + NS::UInteger arrayLength() const; + + NS::UInteger constantBlockAlignment() const; + + DataType dataType() const; + + NS::UInteger index() const; + + ArgumentDescriptor* init(); + + void setAccess(MTL::BindingAccess access); + + void setArrayLength(NS::UInteger arrayLength); + + void setConstantBlockAlignment(NS::UInteger constantBlockAlignment); + + void setDataType(MTL::DataType dataType); + + void setIndex(NS::UInteger index); + + void setTextureType(MTL::TextureType textureType); + TextureType textureType() const; +}; +class Architecture : public NS::Copying +{ +public: + static Architecture* alloc(); + + Architecture* init(); + + NS::String* name() const; +}; +class Device : public NS::Referencing +{ +public: + AccelerationStructureSizes accelerationStructureSizes(const MTL::AccelerationStructureDescriptor* descriptor); + + Architecture* architecture() const; + + bool areBarycentricCoordsSupported() const; + + bool areProgrammableSamplePositionsSupported() const; + + bool areRasterOrderGroupsSupported() const; + + ArgumentBuffersTier argumentBuffersSupport() const; + + [[deprecated("please use areBarycentricCoordsSupported instead")]] + bool barycentricCoordsSupported() const; + + void convertSparsePixelRegions(const MTL::Region* pixelRegions, MTL::Region* tileRegions, MTL::Size tileSize, MTL::SparseTextureRegionAlignmentMode mode, NS::UInteger numRegions); + + void convertSparseTileRegions(const MTL::Region* tileRegions, MTL::Region* pixelRegions, MTL::Size tileSize, NS::UInteger numRegions); + + NS::Array* counterSets() const; + + NS::UInteger 
currentAllocatedSize() const; + + [[deprecated("please use isDepth24Stencil8PixelFormatSupported instead")]] + bool depth24Stencil8PixelFormatSupported() const; + + FunctionHandle* functionHandle(const MTL::Function* function); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function); + + void getDefaultSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + bool hasUnifiedMemory() const; + + [[deprecated("please use isHeadless instead")]] + bool headless() const; + + SizeAndAlign heapAccelerationStructureSizeAndAlign(NS::UInteger size); + SizeAndAlign heapAccelerationStructureSizeAndAlign(const MTL::AccelerationStructureDescriptor* descriptor); + + SizeAndAlign heapBufferSizeAndAlign(NS::UInteger length, MTL::ResourceOptions options); + + SizeAndAlign heapTextureSizeAndAlign(const MTL::TextureDescriptor* desc); + + bool isDepth24Stencil8PixelFormatSupported() const; + + bool isHeadless() const; + + bool isLowPower() const; + + bool isRemovable() const; + + DeviceLocation location() const; + NS::UInteger locationNumber() const; + + [[deprecated("please use isLowPower instead")]] + bool lowPower() const; + + NS::UInteger maxArgumentBufferSamplerCount() const; + + NS::UInteger maxBufferLength() const; + + NS::UInteger maxThreadgroupMemoryLength() const; + + Size maxThreadsPerThreadgroup() const; + + uint64_t maxTransferRate() const; + + NS::UInteger maximumConcurrentCompilationTaskCount() const; + + NS::UInteger minimumLinearTextureAlignmentForPixelFormat(MTL::PixelFormat format); + + NS::UInteger minimumTextureBufferAlignmentForPixelFormat(MTL::PixelFormat format); + + NS::String* name() const; + + AccelerationStructure* newAccelerationStructure(NS::UInteger size); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor); + + MTL4::Archive* newArchive(const NS::URL* url, NS::Error** error); + + ArgumentEncoder* newArgumentEncoder(const NS::Array* arguments); + ArgumentEncoder* 
newArgumentEncoder(const MTL::BufferBinding* bufferBinding); + + MTL4::ArgumentTable* newArgumentTable(const MTL4::ArgumentTableDescriptor* descriptor, NS::Error** error); + + BinaryArchive* newBinaryArchive(const MTL::BinaryArchiveDescriptor* descriptor, NS::Error** error); + + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options, void (^deallocator)(void*, NS::UInteger)); + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options, MTL::SparsePageSize placementSparsePageSize); + + MTL4::CommandAllocator* newCommandAllocator(); + MTL4::CommandAllocator* newCommandAllocator(const MTL4::CommandAllocatorDescriptor* descriptor, NS::Error** error); + + MTL4::CommandBuffer* newCommandBuffer(); + + CommandQueue* newCommandQueue(); + CommandQueue* newCommandQueue(NS::UInteger maxCommandBufferCount); + CommandQueue* newCommandQueue(const MTL::CommandQueueDescriptor* descriptor); + + MTL4::Compiler* newCompiler(const MTL4::CompilerDescriptor* descriptor, NS::Error** error); + + ComputePipelineState* newComputePipelineState(const MTL::Function* computeFunction, NS::Error** error); + ComputePipelineState* newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error); + void newComputePipelineState(const MTL::Function* computeFunction, const MTL::NewComputePipelineStateCompletionHandler completionHandler); + void newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler); + ComputePipelineState* newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* 
reflection, NS::Error** error); + void newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler); + void newComputePipelineState(const MTL::Function* pFunction, const MTL::NewComputePipelineStateCompletionHandlerFunction& completionHandler); + void newComputePipelineState(const MTL::Function* pFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + void newComputePipelineState(const MTL::ComputePipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + + MTL4::CounterHeap* newCounterHeap(const MTL4::CounterHeapDescriptor* descriptor, NS::Error** error); + + CounterSampleBuffer* newCounterSampleBuffer(const MTL::CounterSampleBufferDescriptor* descriptor, NS::Error** error); + + Library* newDefaultLibrary(); + Library* newDefaultLibrary(const NS::Bundle* bundle, NS::Error** error); + + DepthStencilState* newDepthStencilState(const MTL::DepthStencilDescriptor* descriptor); + + DynamicLibrary* newDynamicLibrary(const MTL::Library* library, NS::Error** error); + DynamicLibrary* newDynamicLibrary(const NS::URL* url, NS::Error** error); + + Event* newEvent(); + + Fence* newFence(); + + Heap* newHeap(const MTL::HeapDescriptor* descriptor); + + IOCommandQueue* newIOCommandQueue(const MTL::IOCommandQueueDescriptor* descriptor, NS::Error** error); + + IOFileHandle* newIOFileHandle(const NS::URL* url, NS::Error** error); + IOFileHandle* newIOFileHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error); + + IOFileHandle* newIOHandle(const NS::URL* url, NS::Error** error); + IOFileHandle* newIOHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error); + + IndirectCommandBuffer* newIndirectCommandBuffer(const 
MTL::IndirectCommandBufferDescriptor* descriptor, NS::UInteger maxCount, MTL::ResourceOptions options); + + Library* newLibrary(const NS::String* filepath, NS::Error** error); + Library* newLibrary(const NS::URL* url, NS::Error** error); + Library* newLibrary(const dispatch_data_t data, NS::Error** error); + Library* newLibrary(const NS::String* source, const MTL::CompileOptions* options, NS::Error** error); + void newLibrary(const NS::String* source, const MTL::CompileOptions* options, const MTL::NewLibraryCompletionHandler completionHandler); + Library* newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error); + void newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler); + void newLibrary(const NS::String* pSource, const MTL::CompileOptions* pOptions, const MTL::NewLibraryCompletionHandlerFunction& completionHandler); + void newLibrary(const MTL::StitchedLibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& completionHandler); + + LogState* newLogState(const MTL::LogStateDescriptor* descriptor, NS::Error** error); + + MTL4::CommandQueue* newMTL4CommandQueue(); + MTL4::CommandQueue* newMTL4CommandQueue(const MTL4::CommandQueueDescriptor* descriptor, NS::Error** error); + + MTL4::PipelineDataSetSerializer* newPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializerDescriptor* descriptor); + + RasterizationRateMap* newRasterizationRateMap(const MTL::RasterizationRateMapDescriptor* descriptor); + + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error); + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, const MTL::NewRenderPipelineStateCompletionHandler 
completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + RenderPipelineState* newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + RenderPipelineState* newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error); + void newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, const MTL::NewRenderPipelineStateCompletionHandlerFunction& completionHandler); + void newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + void newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler); + + ResidencySet* newResidencySet(const MTL::ResidencySetDescriptor* desc, NS::Error** error); + + SamplerState* newSamplerState(const MTL::SamplerDescriptor* descriptor); + + SharedEvent* newSharedEvent(); + SharedEvent* newSharedEvent(const MTL::SharedEventHandle* sharedEventHandle); + + Texture* newSharedTexture(const MTL::TextureDescriptor* descriptor); + Texture* newSharedTexture(const 
MTL::SharedTextureHandle* sharedHandle); + + Tensor* newTensor(const MTL::TensorDescriptor* descriptor, NS::Error** error); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor); + Texture* newTexture(const MTL::TextureDescriptor* descriptor, const IOSurfaceRef iosurface, NS::UInteger plane); + TextureViewPool* newTextureViewPool(const MTL::ResourceViewPoolDescriptor* descriptor, NS::Error** error); + + uint32_t peerCount() const; + + uint64_t peerGroupID() const; + + uint32_t peerIndex() const; + + [[deprecated("please use areProgrammableSamplePositionsSupported instead")]] + bool programmableSamplePositionsSupported() const; + + uint64_t queryTimestampFrequency(); + + [[deprecated("please use areRasterOrderGroupsSupported instead")]] + bool rasterOrderGroupsSupported() const; + + ReadWriteTextureTier readWriteTextureSupport() const; + + uint64_t recommendedMaxWorkingSetSize() const; + + uint64_t registryID() const; + + [[deprecated("please use isRemovable instead")]] + bool removable() const; + + void sampleTimestamps(MTL::Timestamp* cpuTimestamp, MTL::Timestamp* gpuTimestamp); + + void setShouldMaximizeConcurrentCompilation(bool shouldMaximizeConcurrentCompilation); + bool shouldMaximizeConcurrentCompilation() const; + + NS::UInteger sizeOfCounterHeapEntry(MTL4::CounterHeapType type); + + Size sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount); + Size sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount, MTL::SparsePageSize sparsePageSize); + NS::UInteger sparseTileSizeInBytes() const; + NS::UInteger sparseTileSizeInBytes(MTL::SparsePageSize sparsePageSize); + + bool supports32BitFloatFiltering() const; + + bool supports32BitMSAA() const; + + bool supportsBCTextureCompression() const; + + bool supportsCounterSampling(MTL::CounterSamplingPoint samplingPoint); + + bool supportsDynamicLibraries() const; + + bool supportsFamily(MTL::GPUFamily gpuFamily); + + 
bool supportsFeatureSet(MTL::FeatureSet featureSet); + + bool supportsFunctionPointers() const; + bool supportsFunctionPointersFromRender() const; + + bool supportsPrimitiveMotionBlur() const; + + bool supportsPullModelInterpolation() const; + + bool supportsQueryTextureLOD() const; + + bool supportsRasterizationRateMap(NS::UInteger layerCount); + + bool supportsRaytracing() const; + bool supportsRaytracingFromRender() const; + + bool supportsRenderDynamicLibraries() const; + + bool supportsShaderBarycentricCoordinates() const; + + bool supportsTextureSampleCount(NS::UInteger sampleCount); + + bool supportsVertexAmplificationCount(NS::UInteger count); + + SizeAndAlign tensorSizeAndAlign(const MTL::TensorDescriptor* descriptor); +}; + +} + +#if defined(MTL_PRIVATE_IMPLEMENTATION) +extern "C" MTL::Device* MTLCreateSystemDefaultDevice(); +extern "C" NS::Array* MTLCopyAllDevices(); +extern "C" NS::Array* MTLCopyAllDevicesWithObserver(NS::Object**, MTL::DeviceNotificationHandlerBlock); +extern "C" void MTLRemoveDeviceObserver(const NS::Object*); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceWasAddedNotification); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceRemovalRequestedNotification); +_MTL_PRIVATE_DEF_WEAK_CONST(MTL::DeviceNotificationName, DeviceWasRemovedNotification); +_MTL_PRIVATE_DEF_CONST(NS::ErrorUserInfoKey, CommandBufferEncoderInfoErrorKey); +_NS_EXPORT MTL::Device* MTL::CreateSystemDefaultDevice() +{ + return ::MTLCreateSystemDefaultDevice(); +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevices() +{ +#if (__IPHONE_OS_VERSION_MIN_REQUIRED >= 180000) || (__MAC_OS_X_VERSION_MIN_REQUIRED >= 101100) + return ::MTLCopyAllDevices(); +#else + return nullptr; +#endif +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, MTL::DeviceNotificationHandlerBlock handler) +{ +#if TARGET_OS_OSX + return ::MTLCopyAllDevicesWithObserver(pOutObserver, handler); +#else + (void)pOutObserver; + (void)handler; + 
return nullptr; +#endif // TARGET_OS_OSX +} + +_NS_EXPORT NS::Array* MTL::CopyAllDevicesWithObserver(NS::Object** pOutObserver, const MTL::DeviceNotificationHandlerFunction& handler) +{ + __block DeviceNotificationHandlerFunction function = handler; + return CopyAllDevicesWithObserver(pOutObserver, ^(Device* pDevice, DeviceNotificationName pNotificationName) { function(pDevice, pNotificationName); }); +} + +_NS_EXPORT void MTL::RemoveDeviceObserver(const NS::Object* pObserver) +{ + (void)pObserver; +#if TARGET_OS_OSX + ::MTLRemoveDeviceObserver(pObserver); +#endif // TARGET_OS_OSX +} + +#endif // MTL_PRIVATE_IMPLEMENTATION + +_MTL_INLINE MTL::BindingAccess MTL::ArgumentDescriptor::access() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(access)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArgumentDescriptor)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::argumentDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLArgumentDescriptor), _MTL_PRIVATE_SEL(argumentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::constantBlockAlignment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantBlockAlignment)); +} + +_MTL_INLINE MTL::DataType MTL::ArgumentDescriptor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE NS::UInteger MTL::ArgumentDescriptor::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::ArgumentDescriptor* MTL::ArgumentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setAccess(MTL::BindingAccess access) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAccess_), access); +} + +_MTL_INLINE void 
MTL::ArgumentDescriptor::setArrayLength(NS::UInteger arrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArrayLength_), arrayLength); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setConstantBlockAlignment(NS::UInteger constantBlockAlignment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantBlockAlignment_), constantBlockAlignment); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setDataType(MTL::DataType dataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDataType_), dataType); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setIndex(NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndex_), index); +} + +_MTL_INLINE void MTL::ArgumentDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE MTL::TextureType MTL::ArgumentDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::Architecture* MTL::Architecture::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLArchitecture)); +} + +_MTL_INLINE MTL::Architecture* MTL::Architecture::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::Architecture::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::AccelerationStructureSizes MTL::Device::accelerationStructureSizes(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(accelerationStructureSizesWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Architecture* MTL::Device::architecture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(architecture)); +} + +_MTL_INLINE bool MTL::Device::areBarycentricCoordsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areBarycentricCoordsSupported)); +} + +_MTL_INLINE bool MTL::Device::areProgrammableSamplePositionsSupported() const +{ + return 
Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areProgrammableSamplePositionsSupported)); +} + +_MTL_INLINE bool MTL::Device::areRasterOrderGroupsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areRasterOrderGroupsSupported)); +} + +_MTL_INLINE MTL::ArgumentBuffersTier MTL::Device::argumentBuffersSupport() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentBuffersSupport)); +} + +_MTL_INLINE bool MTL::Device::barycentricCoordsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areBarycentricCoordsSupported)); +} + +_MTL_INLINE void MTL::Device::convertSparsePixelRegions(const MTL::Region* pixelRegions, MTL::Region* tileRegions, MTL::Size tileSize, MTL::SparseTextureRegionAlignmentMode mode, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(convertSparsePixelRegions_toTileRegions_withTileSize_alignmentMode_numRegions_), pixelRegions, tileRegions, tileSize, mode, numRegions); +} + +_MTL_INLINE void MTL::Device::convertSparseTileRegions(const MTL::Region* tileRegions, MTL::Region* pixelRegions, MTL::Size tileSize, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(convertSparseTileRegions_toPixelRegions_withTileSize_numRegions_), tileRegions, pixelRegions, tileSize, numRegions); +} + +_MTL_INLINE NS::Array* MTL::Device::counterSets() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(counterSets)); +} + +_MTL_INLINE NS::UInteger MTL::Device::currentAllocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(currentAllocatedSize)); +} + +_MTL_INLINE bool MTL::Device::depth24Stencil8PixelFormatSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(isDepth24Stencil8PixelFormatSupported)); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::Device::functionHandle(const MTL::Function* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_), function); +} + +_MTL_INLINE 
MTL::FunctionHandle* MTL::Device::functionHandle(const MTL4::BinaryFunction* function) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_), function); +} + +_MTL_INLINE void MTL::Device::getDefaultSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getDefaultSamplePositions_count_), positions, count); +} + +_MTL_INLINE bool MTL::Device::hasUnifiedMemory() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hasUnifiedMemory)); +} + +_MTL_INLINE bool MTL::Device::headless() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isHeadless)); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapAccelerationStructureSizeAndAlign(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapAccelerationStructureSizeAndAlignWithSize_), size); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapAccelerationStructureSizeAndAlign(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapAccelerationStructureSizeAndAlignWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapBufferSizeAndAlign(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapBufferSizeAndAlignWithLength_options_), length, options); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::heapTextureSizeAndAlign(const MTL::TextureDescriptor* desc) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapTextureSizeAndAlignWithDescriptor_), desc); +} + +_MTL_INLINE bool MTL::Device::isDepth24Stencil8PixelFormatSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(isDepth24Stencil8PixelFormatSupported)); +} + +_MTL_INLINE bool MTL::Device::isHeadless() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isHeadless)); +} + +_MTL_INLINE bool MTL::Device::isLowPower() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(isLowPower)); +} + +_MTL_INLINE bool MTL::Device::isRemovable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRemovable)); +} + +_MTL_INLINE MTL::DeviceLocation MTL::Device::location() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(location)); +} + +_MTL_INLINE NS::UInteger MTL::Device::locationNumber() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(locationNumber)); +} + +_MTL_INLINE bool MTL::Device::lowPower() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isLowPower)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxArgumentBufferSamplerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxArgumentBufferSamplerCount)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxBufferLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxBufferLength)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maxThreadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxThreadgroupMemoryLength)); +} + +_MTL_INLINE MTL::Size MTL::Device::maxThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxThreadsPerThreadgroup)); +} + +_MTL_INLINE uint64_t MTL::Device::maxTransferRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTransferRate)); +} + +_MTL_INLINE NS::UInteger MTL::Device::maximumConcurrentCompilationTaskCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maximumConcurrentCompilationTaskCount)); +} + +_MTL_INLINE NS::UInteger MTL::Device::minimumLinearTextureAlignmentForPixelFormat(MTL::PixelFormat format) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minimumLinearTextureAlignmentForPixelFormat_), format); +} + +_MTL_INLINE NS::UInteger MTL::Device::minimumTextureBufferAlignmentForPixelFormat(MTL::PixelFormat format) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minimumTextureBufferAlignmentForPixelFormat_), format); +} + +_MTL_INLINE NS::String* 
MTL::Device::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Device::newAccelerationStructure(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_), size); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Device::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::Archive* MTL::Device::newArchive(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArchiveWithURL_error_), url, error); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Device::newArgumentEncoder(const NS::Array* arguments) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithArguments_), arguments); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Device::newArgumentEncoder(const MTL::BufferBinding* bufferBinding) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferBinding_), bufferBinding); +} + +_MTL_INLINE MTL4::ArgumentTable* MTL::Device::newArgumentTable(const MTL4::ArgumentTableDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentTableWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::BinaryArchive* MTL::Device::newBinaryArchive(const MTL::BinaryArchiveDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBinaryArchiveWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_), length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options) +{ + 
return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithBytes_length_options_), pointer, length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(const void* pointer, NS::UInteger length, MTL::ResourceOptions options, void (^deallocator)(void*, NS::UInteger)) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithBytesNoCopy_length_options_deallocator_), pointer, length, options, deallocator); +} + +_MTL_INLINE MTL::Buffer* MTL::Device::newBuffer(NS::UInteger length, MTL::ResourceOptions options, MTL::SparsePageSize placementSparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_placementSparsePageSize_), length, options, placementSparsePageSize); +} + +_MTL_INLINE MTL4::CommandAllocator* MTL::Device::newCommandAllocator() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandAllocator)); +} + +_MTL_INLINE MTL4::CommandAllocator* MTL::Device::newCommandAllocator(const MTL4::CommandAllocatorDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandAllocatorWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CommandBuffer* MTL::Device::newCommandBuffer() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandBuffer)); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueue)); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(NS::UInteger maxCommandBufferCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithMaxCommandBufferCount_), maxCommandBufferCount); +} + +_MTL_INLINE MTL::CommandQueue* MTL::Device::newCommandQueue(const MTL::CommandQueueDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCommandQueueWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL4::Compiler* MTL::Device::newCompiler(const MTL4::CompilerDescriptor* descriptor, NS::Error** 
error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCompilerWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_error_), computeFunction, error); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_options_reflection_error_), computeFunction, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, const MTL::NewComputePipelineStateCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_completionHandler_), computeFunction, completionHandler); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* computeFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithFunction_options_completionHandler_), computeFunction, options, completionHandler); +} + +_MTL_INLINE MTL::ComputePipelineState* MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedComputePipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* descriptor, MTL::PipelineOption 
options, const MTL::NewComputePipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* pFunction, const MTL::NewComputePipelineStateCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewComputePipelineStateCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pFunction, ^(MTL::ComputePipelineState* pPipelineState, NS::Error* pError) { blockCompletionHandler(pPipelineState, pError); }); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::Function* pFunction, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pFunction, options, ^(MTL::ComputePipelineState* pPipelineState, MTL::ComputePipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE void MTL::Device::newComputePipelineState(const MTL::ComputePipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewComputePipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block NewComputePipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newComputePipelineState(pDescriptor, options, ^(ComputePipelineState* pPipelineState, ComputePipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE MTL4::CounterHeap* MTL::Device::newCounterHeap(const MTL4::CounterHeapDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newCounterHeapWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::Device::newCounterSampleBuffer(const MTL::CounterSampleBufferDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newCounterSampleBufferWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newDefaultLibrary() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDefaultLibrary)); +} + +_MTL_INLINE MTL::Library* MTL::Device::newDefaultLibrary(const NS::Bundle* bundle, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDefaultLibraryWithBundle_error_), bundle, error); +} + +_MTL_INLINE MTL::DepthStencilState* MTL::Device::newDepthStencilState(const MTL::DepthStencilDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDepthStencilStateWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL::Device::newDynamicLibrary(const MTL::Library* library, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibrary_error_), library, error); +} + +_MTL_INLINE MTL::DynamicLibrary* MTL::Device::newDynamicLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newDynamicLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL::Event* MTL::Device::newEvent() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newEvent)); +} + +_MTL_INLINE MTL::Fence* MTL::Device::newFence() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFence)); +} + +_MTL_INLINE MTL::Heap* MTL::Device::newHeap(const MTL::HeapDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newHeapWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::IOCommandQueue* MTL::Device::newIOCommandQueue(const MTL::IOCommandQueueDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newIOCommandQueueWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOFileHandle(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOFileHandleWithURL_error_), url, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOFileHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOFileHandleWithURL_compressionMethod_error_), url, compressionMethod, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOHandle(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOHandleWithURL_error_), url, error); +} + +_MTL_INLINE MTL::IOFileHandle* MTL::Device::newIOHandle(const NS::URL* url, MTL::IOCompressionMethod compressionMethod, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIOHandleWithURL_compressionMethod_error_), url, compressionMethod, error); +} + +_MTL_INLINE MTL::IndirectCommandBuffer* MTL::Device::newIndirectCommandBuffer(const MTL::IndirectCommandBufferDescriptor* descriptor, NS::UInteger maxCount, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIndirectCommandBufferWithDescriptor_maxCommandCount_options_), descriptor, maxCount, options); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::String* filepath, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithFile_error_), filepath, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithURL_error_), url, error); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const dispatch_data_t data, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithData_error_), data, error); +} + 
+_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const NS::String* source, const MTL::CompileOptions* options, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithSource_options_error_), source, options, error); +} + +_MTL_INLINE void MTL::Device::newLibrary(const NS::String* source, const MTL::CompileOptions* options, const MTL::NewLibraryCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithSource_options_completionHandler_), source, options, completionHandler); +} + +_MTL_INLINE MTL::Library* MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithStitchedDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* descriptor, const MTL::NewLibraryCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newLibraryWithStitchedDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE void MTL::Device::newLibrary(const NS::String* pSource, const MTL::CompileOptions* pOptions, const MTL::NewLibraryCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockCompletionHandler = completionHandler; + newLibrary(pSource, pOptions, ^(MTL::Library* pLibrary, NS::Error* pError) { blockCompletionHandler(pLibrary, pError); }); +} + +_MTL_INLINE void MTL::Device::newLibrary(const MTL::StitchedLibraryDescriptor* pDescriptor, const MTL::NewLibraryCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewLibraryCompletionHandlerFunction blockCompletionHandler = completionHandler; + newLibrary(pDescriptor, ^(MTL::Library* pLibrary, NS::Error* pError) { blockCompletionHandler(pLibrary, pError); }); +} + +_MTL_INLINE MTL::LogState* MTL::Device::newLogState(const MTL::LogStateDescriptor* descriptor, NS::Error** error) +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(newLogStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::CommandQueue* MTL::Device::newMTL4CommandQueue() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMTL4CommandQueue)); +} + +_MTL_INLINE MTL4::CommandQueue* MTL::Device::newMTL4CommandQueue(const MTL4::CommandQueueDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newMTL4CommandQueueWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL4::PipelineDataSetSerializer* MTL::Device::newPipelineDataSetSerializer(const MTL4::PipelineDataSetSerializerDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newPipelineDataSetSerializerWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL::Device::newRasterizationRateMap(const MTL::RasterizationRateMapDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRasterizationRateMapWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, const MTL::NewRenderPipelineStateCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_completionHandler_), descriptor, 
completionHandler); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithTileDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithTileDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::Device::newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::AutoreleasedRenderPipelineReflection* reflection, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithMeshDescriptor_options_reflection_error_), descriptor, options, reflection, error); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::MeshRenderPipelineDescriptor* descriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandler completionHandler) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithMeshDescriptor_options_completionHandler_), descriptor, options, completionHandler); +} + +_MTL_INLINE 
void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, const MTL::NewRenderPipelineStateCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, ^(MTL::RenderPipelineState* pPipelineState, NS::Error* pError) { blockCompletionHandler(pPipelineState, pError); }); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::RenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipelineState, MTL::RenderPipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE void MTL::Device::newRenderPipelineState(const MTL::TileRenderPipelineDescriptor* pDescriptor, MTL::PipelineOption options, const MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction& completionHandler) +{ + __block MTL::NewRenderPipelineStateWithReflectionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newRenderPipelineState(pDescriptor, options, ^(MTL::RenderPipelineState* pPipelineState, MTL::RenderPipelineReflection* pReflection, NS::Error* pError) { blockCompletionHandler(pPipelineState, pReflection, pError); }); +} + +_MTL_INLINE MTL::ResidencySet* MTL::Device::newResidencySet(const MTL::ResidencySetDescriptor* desc, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newResidencySetWithDescriptor_error_), desc, error); +} + +_MTL_INLINE MTL::SamplerState* MTL::Device::newSamplerState(const MTL::SamplerDescriptor* descriptor) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(newSamplerStateWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::SharedEvent* MTL::Device::newSharedEvent() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEvent)); +} + +_MTL_INLINE MTL::SharedEvent* MTL::Device::newSharedEvent(const MTL::SharedEventHandle* sharedEventHandle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEventWithHandle_), sharedEventHandle); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newSharedTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newSharedTexture(const MTL::SharedTextureHandle* sharedHandle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureWithHandle_), sharedHandle); +} + +_MTL_INLINE MTL::Tensor* MTL::Device::newTensor(const MTL::TensorDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTensorWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Device::newTexture(const MTL::TextureDescriptor* descriptor, const IOSurfaceRef iosurface, NS::UInteger plane) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_iosurface_plane_), descriptor, iosurface, plane); +} + +_MTL_INLINE MTL::TextureViewPool* MTL::Device::newTextureViewPool(const MTL::ResourceViewPoolDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewPoolWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE uint32_t MTL::Device::peerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(peerCount)); +} + +_MTL_INLINE uint64_t MTL::Device::peerGroupID() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(peerGroupID)); +} + +_MTL_INLINE uint32_t MTL::Device::peerIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(peerIndex)); +} + +_MTL_INLINE bool MTL::Device::programmableSamplePositionsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areProgrammableSamplePositionsSupported)); +} + +_MTL_INLINE uint64_t MTL::Device::queryTimestampFrequency() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(queryTimestampFrequency)); +} + +_MTL_INLINE bool MTL::Device::rasterOrderGroupsSupported() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(areRasterOrderGroupsSupported)); +} + +_MTL_INLINE MTL::ReadWriteTextureTier MTL::Device::readWriteTextureSupport() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(readWriteTextureSupport)); +} + +_MTL_INLINE uint64_t MTL::Device::recommendedMaxWorkingSetSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(recommendedMaxWorkingSetSize)); +} + +_MTL_INLINE uint64_t MTL::Device::registryID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(registryID)); +} + +_MTL_INLINE bool MTL::Device::removable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRemovable)); +} + +_MTL_INLINE void MTL::Device::sampleTimestamps(MTL::Timestamp* cpuTimestamp, MTL::Timestamp* gpuTimestamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleTimestamps_gpuTimestamp_), cpuTimestamp, gpuTimestamp); +} + +_MTL_INLINE void MTL::Device::setShouldMaximizeConcurrentCompilation(bool shouldMaximizeConcurrentCompilation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShouldMaximizeConcurrentCompilation_), shouldMaximizeConcurrentCompilation); +} + +_MTL_INLINE bool MTL::Device::shouldMaximizeConcurrentCompilation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shouldMaximizeConcurrentCompilation)); +} + +_MTL_INLINE NS::UInteger MTL::Device::sizeOfCounterHeapEntry(MTL4::CounterHeapType 
type) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sizeOfCounterHeapEntry_), type); +} + +_MTL_INLINE MTL::Size MTL::Device::sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_), textureType, pixelFormat, sampleCount); +} + +_MTL_INLINE MTL::Size MTL::Device::sparseTileSize(MTL::TextureType textureType, MTL::PixelFormat pixelFormat, NS::UInteger sampleCount, MTL::SparsePageSize sparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_sparsePageSize_), textureType, pixelFormat, sampleCount, sparsePageSize); +} + +_MTL_INLINE NS::UInteger MTL::Device::sparseTileSizeInBytes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeInBytes)); +} + +_MTL_INLINE NS::UInteger MTL::Device::sparseTileSizeInBytes(MTL::SparsePageSize sparsePageSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTileSizeInBytesForSparsePageSize_), sparsePageSize); +} + +_MTL_INLINE bool MTL::Device::supports32BitFloatFiltering() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supports32BitFloatFiltering)); +} + +_MTL_INLINE bool MTL::Device::supports32BitMSAA() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supports32BitMSAA)); +} + +_MTL_INLINE bool MTL::Device::supportsBCTextureCompression() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsBCTextureCompression)); +} + +_MTL_INLINE bool MTL::Device::supportsCounterSampling(MTL::CounterSamplingPoint samplingPoint) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsCounterSampling_), samplingPoint); +} + +_MTL_INLINE bool MTL::Device::supportsDynamicLibraries() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsDynamicLibraries)); +} + +_MTL_INLINE bool 
MTL::Device::supportsFamily(MTL::GPUFamily gpuFamily) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFamily_), gpuFamily); +} + +_MTL_INLINE bool MTL::Device::supportsFeatureSet(MTL::FeatureSet featureSet) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFeatureSet_), featureSet); +} + +_MTL_INLINE bool MTL::Device::supportsFunctionPointers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFunctionPointers)); +} + +_MTL_INLINE bool MTL::Device::supportsFunctionPointersFromRender() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsFunctionPointersFromRender)); +} + +_MTL_INLINE bool MTL::Device::supportsPrimitiveMotionBlur() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsPrimitiveMotionBlur)); +} + +_MTL_INLINE bool MTL::Device::supportsPullModelInterpolation() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsPullModelInterpolation)); +} + +_MTL_INLINE bool MTL::Device::supportsQueryTextureLOD() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsQueryTextureLOD)); +} + +_MTL_INLINE bool MTL::Device::supportsRasterizationRateMap(NS::UInteger layerCount) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRasterizationRateMapWithLayerCount_), layerCount); +} + +_MTL_INLINE bool MTL::Device::supportsRaytracing() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRaytracing)); +} + +_MTL_INLINE bool MTL::Device::supportsRaytracingFromRender() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRaytracingFromRender)); +} + +_MTL_INLINE bool MTL::Device::supportsRenderDynamicLibraries() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsRenderDynamicLibraries)); +} + +_MTL_INLINE bool MTL::Device::supportsShaderBarycentricCoordinates() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsShaderBarycentricCoordinates)); 
+} + +_MTL_INLINE bool MTL::Device::supportsTextureSampleCount(NS::UInteger sampleCount) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsTextureSampleCount_), sampleCount); +} + +_MTL_INLINE bool MTL::Device::supportsVertexAmplificationCount(NS::UInteger count) +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportsVertexAmplificationCount_), count); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::Device::tensorSizeAndAlign(const MTL::TensorDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tensorSizeAndAlignWithDescriptor_), descriptor); +} diff --git a/dist/include/metal_cpp/Metal/MTLDrawable.hpp b/dist/include/metal_cpp/Metal/MTLDrawable.hpp new file mode 100644 index 0000000..fad4fed --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDrawable.hpp @@ -0,0 +1,90 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDrawable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +#include +#include + +namespace MTL +{ +class Drawable; + +using DrawablePresentedHandler = void (^)(MTL::Drawable*); +using DrawablePresentedHandlerFunction = std::function; + +class Drawable : public NS::Referencing +{ +public: + void addPresentedHandler(const MTL::DrawablePresentedHandler block); + void addPresentedHandler(const MTL::DrawablePresentedHandlerFunction& function); + + NS::UInteger drawableID() const; + + void present(); + void presentAfterMinimumDuration(CFTimeInterval duration); + + void presentAtTime(CFTimeInterval presentationTime); + + CFTimeInterval presentedTime() const; +}; + +} +_MTL_INLINE void MTL::Drawable::addPresentedHandler(const MTL::DrawablePresentedHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addPresentedHandler_), block); +} + +_MTL_INLINE void MTL::Drawable::addPresentedHandler(const MTL::DrawablePresentedHandlerFunction& function) +{ + __block DrawablePresentedHandlerFunction blockFunction = function; + addPresentedHandler(^(Drawable* pDrawable) { blockFunction(pDrawable); }); +} + +_MTL_INLINE NS::UInteger MTL::Drawable::drawableID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(drawableID)); +} + +_MTL_INLINE void MTL::Drawable::present() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(present)); +} + +_MTL_INLINE void MTL::Drawable::presentAfterMinimumDuration(CFTimeInterval duration) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentAfterMinimumDuration_), duration); +} + +_MTL_INLINE void MTL::Drawable::presentAtTime(CFTimeInterval presentationTime) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(presentAtTime_), presentationTime); +} + +_MTL_INLINE CFTimeInterval 
MTL::Drawable::presentedTime() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(presentedTime)); +} diff --git a/dist/include/metal_cpp/Metal/MTLDynamicLibrary.hpp b/dist/include/metal_cpp/Metal/MTLDynamicLibrary.hpp new file mode 100644 index 0000000..0726acc --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLDynamicLibrary.hpp @@ -0,0 +1,78 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLDynamicLibrary.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; +_MTL_ENUM(NS::UInteger, DynamicLibraryError) { + DynamicLibraryErrorNone = 0, + DynamicLibraryErrorInvalidFile = 1, + DynamicLibraryErrorCompilationFailure = 2, + DynamicLibraryErrorUnresolvedInstallName = 3, + DynamicLibraryErrorDependencyLoadFailure = 4, + DynamicLibraryErrorUnsupported = 5, +}; + +class DynamicLibrary : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* installName() const; + + NS::String* label() const; + + bool serializeToURL(const NS::URL* url, NS::Error** error); + + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE MTL::Device* MTL::DynamicLibrary::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::DynamicLibrary::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE NS::String* MTL::DynamicLibrary::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::DynamicLibrary::serializeToURL(const NS::URL* url, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(serializeToURL_error_), url, error); +} + +_MTL_INLINE void MTL::DynamicLibrary::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTLEvent.hpp b/dist/include/metal_cpp/Metal/MTLEvent.hpp new file mode 100644 index 0000000..d06b969 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLEvent.hpp @@ -0,0 +1,170 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLEvent.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include +#include + +#include +#include + +namespace MTL +{ +class Device; +class SharedEvent; +class SharedEventHandle; +class SharedEventListener; + +using SharedEventNotificationBlock = void (^)(SharedEvent* pEvent, std::uint64_t value); +using SharedEventNotificationFunction = std::function; + +class Event : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* label); +}; +class SharedEventListener : public NS::Referencing +{ +public: + static SharedEventListener* alloc(); + + dispatch_queue_t dispatchQueue() const; + + SharedEventListener* init(); + SharedEventListener* init(const dispatch_queue_t dispatchQueue); + + static SharedEventListener* sharedListener(); +}; +class SharedEvent : public NS::Referencing +{ +public: + SharedEventHandle* newSharedEventHandle(); + + void notifyListener(const 
MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationBlock block); + void notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationFunction& function); + + void setSignaledValue(uint64_t signaledValue); + uint64_t signaledValue() const; + bool waitUntilSignaledValue(uint64_t value, uint64_t milliseconds); +}; +class SharedEventHandle : public NS::SecureCoding +{ +public: + static SharedEventHandle* alloc(); + + SharedEventHandle* init(); + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::Device* MTL::Event::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::Event::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Event::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedEventListener)); +} + +_MTL_INLINE dispatch_queue_t MTL::SharedEventListener::dispatchQueue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchQueue)); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::init(const dispatch_queue_t dispatchQueue) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithDispatchQueue_), dispatchQueue); +} + +_MTL_INLINE MTL::SharedEventListener* MTL::SharedEventListener::sharedListener() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLSharedEventListener), _MTL_PRIVATE_SEL(sharedListener)); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEvent::newSharedEventHandle() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedEventHandle)); +} + +_MTL_INLINE void MTL::SharedEvent::notifyListener(const 
MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationBlock block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(notifyListener_atValue_block_), listener, value, block); +} + +_MTL_INLINE void MTL::SharedEvent::notifyListener(const MTL::SharedEventListener* listener, uint64_t value, const MTL::SharedEventNotificationFunction& function) +{ + __block MTL::SharedEventNotificationFunction callback = function; + notifyListener(listener, value, ^void(SharedEvent* pEvent, std::uint64_t innerValue) { callback(pEvent, innerValue); }); +} + +_MTL_INLINE void MTL::SharedEvent::setSignaledValue(uint64_t signaledValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSignaledValue_), signaledValue); +} + +_MTL_INLINE uint64_t MTL::SharedEvent::signaledValue() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(signaledValue)); +} + +_MTL_INLINE bool MTL::SharedEvent::waitUntilSignaledValue(uint64_t value, uint64_t milliseconds) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilSignaledValue_timeoutMS_), value, milliseconds); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEventHandle::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedEventHandle)); +} + +_MTL_INLINE MTL::SharedEventHandle* MTL::SharedEventHandle::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SharedEventHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} diff --git a/dist/include/metal_cpp/Metal/MTLFence.hpp b/dist/include/metal_cpp/Metal/MTLFence.hpp new file mode 100644 index 0000000..f31df4c --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFence.hpp @@ -0,0 +1,55 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFence.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Device; + +class Fence : public NS::Referencing +{ +public: + Device* device() const; + + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +} +_MTL_INLINE MTL::Device* MTL::Fence::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::Fence::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Fence::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp b/dist/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp new file mode 100644 index 0000000..dce89d1 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp @@ -0,0 +1,76 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionConstantValues.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionConstantValues; + +class FunctionConstantValues : public NS::Copying +{ +public: + static FunctionConstantValues* alloc(); + + FunctionConstantValues* init(); + + void reset(); + + void setConstantValue(const void* value, MTL::DataType type, NS::UInteger index); + void setConstantValue(const void* value, MTL::DataType type, const NS::String* name); + void setConstantValues(const void* values, MTL::DataType type, NS::Range range); +}; + +} +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionConstantValues::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionConstantValues)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionConstantValues::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::FunctionConstantValues::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValue(const void* value, MTL::DataType type, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValue_type_atIndex_), value, type, index); 
+} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValue(const void* value, MTL::DataType type, const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValue_type_withName_), value, type, name); +} + +_MTL_INLINE void MTL::FunctionConstantValues::setConstantValues(const void* values, MTL::DataType type, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_type_withRange_), values, type, range); +} diff --git a/dist/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp b/dist/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp new file mode 100644 index 0000000..aa296b5 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp @@ -0,0 +1,153 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionConstantValues; +class FunctionDescriptor; +class IntersectionFunctionDescriptor; + +_MTL_OPTIONS(NS::UInteger, FunctionOptions) { + FunctionOptionNone = 0, + FunctionOptionCompileToBinary = 1, + FunctionOptionStoreFunctionInMetalPipelinesScript = 1 << 1, + FunctionOptionStoreFunctionInMetalScript = 1 << 1, + FunctionOptionFailOnBinaryArchiveMiss = 1 << 2, + FunctionOptionPipelineIndependent = 1 << 3, +}; + +class FunctionDescriptor : public NS::Copying +{ +public: + static FunctionDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + FunctionConstantValues* constantValues() const; + + static FunctionDescriptor* functionDescriptor(); + + FunctionDescriptor* init(); + + NS::String* name() const; + + FunctionOptions options() const; + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setConstantValues(const MTL::FunctionConstantValues* constantValues); + + void setName(const NS::String* name); + + void setOptions(MTL::FunctionOptions options); + + void setSpecializedName(const NS::String* specializedName); + NS::String* specializedName() const; +}; +class IntersectionFunctionDescriptor : public NS::Copying +{ +public: + static IntersectionFunctionDescriptor* alloc(); + + IntersectionFunctionDescriptor* init(); +}; + +} +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::FunctionDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::FunctionConstantValues* MTL::FunctionDescriptor::constantValues() const 
+{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(constantValues)); +} + +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::functionDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLFunctionDescriptor), _MTL_PRIVATE_SEL(functionDescriptor)); +} + +_MTL_INLINE MTL::FunctionDescriptor* MTL::FunctionDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::FunctionDescriptor::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::FunctionOptions MTL::FunctionDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setConstantValues(const MTL::FunctionConstantValues* constantValues) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setConstantValues_), constantValues); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setName(const NS::String* name) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setOptions(MTL::FunctionOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} + +_MTL_INLINE void MTL::FunctionDescriptor::setSpecializedName(const NS::String* specializedName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSpecializedName_), specializedName); +} + +_MTL_INLINE NS::String* MTL::FunctionDescriptor::specializedName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(specializedName)); +} + +_MTL_INLINE MTL::IntersectionFunctionDescriptor* MTL::IntersectionFunctionDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIntersectionFunctionDescriptor)); +} + +_MTL_INLINE MTL::IntersectionFunctionDescriptor* MTL::IntersectionFunctionDescriptor::init() +{ + return 
NS::Object::init(); +} diff --git a/dist/include/metal_cpp/Metal/MTLFunctionHandle.hpp b/dist/include/metal_cpp/Metal/MTLFunctionHandle.hpp new file mode 100644 index 0000000..7a3ff95 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFunctionHandle.hpp @@ -0,0 +1,65 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionHandle.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLLibrary.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; + +class FunctionHandle : public NS::Referencing +{ +public: + Device* device() const; + + FunctionType functionType() const; + + ResourceID gpuResourceID() const; + + NS::String* name() const; +}; + +} +_MTL_INLINE MTL::Device* MTL::FunctionHandle::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::FunctionType MTL::FunctionHandle::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE MTL::ResourceID MTL::FunctionHandle::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::FunctionHandle::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} diff --git a/dist/include/metal_cpp/Metal/MTLFunctionLog.hpp b/dist/include/metal_cpp/Metal/MTLFunctionLog.hpp new file mode 100644 index 0000000..454e605 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFunctionLog.hpp @@ -0,0 +1,101 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionLog.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Function; +class FunctionLogDebugLocation; +_MTL_ENUM(NS::UInteger, FunctionLogType) { + FunctionLogTypeValidation = 0, +}; + +class LogContainer : public NS::Referencing +{ +}; +class FunctionLogDebugLocation : public NS::Referencing +{ +public: + NS::URL* URL() const; + + NS::UInteger column() const; + + NS::String* functionName() const; + + NS::UInteger line() const; +}; +class FunctionLog : public NS::Referencing +{ +public: + FunctionLogDebugLocation* debugLocation() const; + + NS::String* encoderLabel() const; + + Function* function() const; + + FunctionLogType type() const; +}; + +} +_MTL_INLINE NS::URL* MTL::FunctionLogDebugLocation::URL() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(URL)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionLogDebugLocation::column() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(column)); +} + +_MTL_INLINE NS::String* MTL::FunctionLogDebugLocation::functionName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionName)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionLogDebugLocation::line() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(line)); +} + +_MTL_INLINE MTL::FunctionLogDebugLocation* 
MTL::FunctionLog::debugLocation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(debugLocation)); +} + +_MTL_INLINE NS::String* MTL::FunctionLog::encoderLabel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(encoderLabel)); +} + +_MTL_INLINE MTL::Function* MTL::FunctionLog::function() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(function)); +} + +_MTL_INLINE MTL::FunctionLogType MTL::FunctionLog::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/dist/include/metal_cpp/Metal/MTLFunctionStitching.hpp b/dist/include/metal_cpp/Metal/MTLFunctionStitching.hpp new file mode 100644 index 0000000..8dd5fd2 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLFunctionStitching.hpp @@ -0,0 +1,319 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLFunctionStitching.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class FunctionStitchingAttributeAlwaysInline; +class FunctionStitchingFunctionNode; +class FunctionStitchingGraph; +class FunctionStitchingInputNode; +class StitchedLibraryDescriptor; + +_MTL_OPTIONS(NS::UInteger, StitchedLibraryOptions) { + StitchedLibraryOptionNone = 0, + StitchedLibraryOptionFailOnBinaryArchiveMiss = 1, + StitchedLibraryOptionStoreLibraryInMetalPipelinesScript = 1 << 1, +}; + +class FunctionStitchingAttribute : public NS::Referencing +{ +}; +class FunctionStitchingAttributeAlwaysInline : public NS::Referencing +{ +public: + static FunctionStitchingAttributeAlwaysInline* alloc(); + + FunctionStitchingAttributeAlwaysInline* init(); +}; +class FunctionStitchingNode : public NS::Copying +{ +}; +class FunctionStitchingInputNode : public NS::Referencing +{ +public: + static FunctionStitchingInputNode* alloc(); + + NS::UInteger argumentIndex() const; + + FunctionStitchingInputNode* init(); + FunctionStitchingInputNode* init(NS::UInteger argument); + + void setArgumentIndex(NS::UInteger argumentIndex); +}; +class FunctionStitchingFunctionNode : public NS::Referencing +{ +public: + static FunctionStitchingFunctionNode* alloc(); + + NS::Array* arguments() const; + + NS::Array* controlDependencies() const; + + FunctionStitchingFunctionNode* init(); + FunctionStitchingFunctionNode* init(const NS::String* name, const NS::Array* arguments, const NS::Array* controlDependencies); + + NS::String* name() const; + + void setArguments(const NS::Array* arguments); + + void setControlDependencies(const NS::Array* controlDependencies); + + void setName(const NS::String* name); +}; +class FunctionStitchingGraph : public NS::Copying +{ +public: 
+ static FunctionStitchingGraph* alloc(); + + NS::Array* attributes() const; + + NS::String* functionName() const; + + FunctionStitchingGraph* init(); + FunctionStitchingGraph* init(const NS::String* functionName, const NS::Array* nodes, const MTL::FunctionStitchingFunctionNode* outputNode, const NS::Array* attributes); + + NS::Array* nodes() const; + + FunctionStitchingFunctionNode* outputNode() const; + + void setAttributes(const NS::Array* attributes); + + void setFunctionName(const NS::String* functionName); + + void setNodes(const NS::Array* nodes); + + void setOutputNode(const MTL::FunctionStitchingFunctionNode* outputNode); +}; +class StitchedLibraryDescriptor : public NS::Copying +{ +public: + static StitchedLibraryDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + NS::Array* functionGraphs() const; + + NS::Array* functions() const; + + StitchedLibraryDescriptor* init(); + + StitchedLibraryOptions options() const; + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setFunctionGraphs(const NS::Array* functionGraphs); + + void setFunctions(const NS::Array* functions); + + void setOptions(MTL::StitchedLibraryOptions options); +}; + +} +_MTL_INLINE MTL::FunctionStitchingAttributeAlwaysInline* MTL::FunctionStitchingAttributeAlwaysInline::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingAttributeAlwaysInline)); +} + +_MTL_INLINE MTL::FunctionStitchingAttributeAlwaysInline* MTL::FunctionStitchingAttributeAlwaysInline::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingInputNode)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionStitchingInputNode::argumentIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(argumentIndex)); +} + +_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::init() +{ + return NS::Object::init(); +} + 
+_MTL_INLINE MTL::FunctionStitchingInputNode* MTL::FunctionStitchingInputNode::init(NS::UInteger argument) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithArgumentIndex_), argument); +} + +_MTL_INLINE void MTL::FunctionStitchingInputNode::setArgumentIndex(NS::UInteger argumentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArgumentIndex_), argumentIndex); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingFunctionNode)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingFunctionNode::arguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arguments)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingFunctionNode::controlDependencies() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(controlDependencies)); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingFunctionNode::init(const NS::String* name, const NS::Array* arguments, const NS::Array* controlDependencies) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithName_arguments_controlDependencies_), name, arguments, controlDependencies); +} + +_MTL_INLINE NS::String* MTL::FunctionStitchingFunctionNode::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setArguments(const NS::Array* arguments) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArguments_), arguments); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setControlDependencies(const NS::Array* controlDependencies) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setControlDependencies_), controlDependencies); +} + +_MTL_INLINE void MTL::FunctionStitchingFunctionNode::setName(const NS::String* name) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setName_), name); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionStitchingGraph)); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingGraph::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE NS::String* MTL::FunctionStitchingGraph::functionName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionName)); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::FunctionStitchingGraph* MTL::FunctionStitchingGraph::init(const NS::String* functionName, const NS::Array* nodes, const MTL::FunctionStitchingFunctionNode* outputNode, const NS::Array* attributes) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithFunctionName_nodes_outputNode_attributes_), functionName, nodes, outputNode, attributes); +} + +_MTL_INLINE NS::Array* MTL::FunctionStitchingGraph::nodes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(nodes)); +} + +_MTL_INLINE MTL::FunctionStitchingFunctionNode* MTL::FunctionStitchingGraph::outputNode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(outputNode)); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setAttributes(const NS::Array* attributes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAttributes_), attributes); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setFunctionName(const NS::String* functionName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionName_), functionName); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setNodes(const NS::Array* nodes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setNodes_), nodes); +} + +_MTL_INLINE void MTL::FunctionStitchingGraph::setOutputNode(const MTL::FunctionStitchingFunctionNode* outputNode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOutputNode_), outputNode); +} + 
+_MTL_INLINE MTL::StitchedLibraryDescriptor* MTL::StitchedLibraryDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStitchedLibraryDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::functionGraphs() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionGraphs)); +} + +_MTL_INLINE NS::Array* MTL::StitchedLibraryDescriptor::functions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functions)); +} + +_MTL_INLINE MTL::StitchedLibraryDescriptor* MTL::StitchedLibraryDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::StitchedLibraryOptions MTL::StitchedLibraryDescriptor::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setFunctionGraphs(const NS::Array* functionGraphs) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionGraphs_), functionGraphs); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setFunctions(const NS::Array* functions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_), functions); +} + +_MTL_INLINE void MTL::StitchedLibraryDescriptor::setOptions(MTL::StitchedLibraryOptions options) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptions_), options); +} diff --git a/dist/include/metal_cpp/Metal/MTLGPUAddress.hpp b/dist/include/metal_cpp/Metal/MTLGPUAddress.hpp new file mode 100644 index 0000000..fb9d61d --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLGPUAddress.hpp @@ -0,0 +1,36 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLGPUAddress.hpp +// +// Copyright 2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#ifdef __METAL_VERSION__ + +#include + +#else + +#include + +#endif // __METAL_VERSION__ + +namespace MTL +{ + using GPUAddress = uint64_t; +} diff --git a/dist/include/metal_cpp/Metal/MTLHeaderBridge.hpp b/dist/include/metal_cpp/Metal/MTLHeaderBridge.hpp new file mode 100644 index 0000000..6a3a142 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLHeaderBridge.hpp @@ -0,0 +1,3120 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLHeaderBridge.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once +#include "MTLPrivate.hpp" + +namespace MTL::Private::Class +{ + +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureMotionTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4AccelerationStructureTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4ArgumentTableDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4BinaryFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommandAllocatorDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommandBufferOptions); +_MTL_PRIVATE_DEF_CLS(MTL4CommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CommitOptions); +_MTL_PRIVATE_DEF_CLS(MTL4CompilerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CompilerTaskOptions); +_MTL_PRIVATE_DEF_CLS(MTL4ComputePipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4CounterHeapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4FunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4IndirectInstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4InstanceAccelerationStructureDescriptor); 
+_MTL_PRIVATE_DEF_CLS(MTL4LibraryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4LibraryFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4MachineLearningPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4MachineLearningPipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTL4MeshRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineDataSetSerializerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineOptions); +_MTL_PRIVATE_DEF_CLS(MTL4PipelineStageDynamicLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4PrimitiveAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineBinaryFunctionsDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4RenderPipelineDynamicLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4SpecializedFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4StaticLinkingDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4StitchedFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTL4TileRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionBoundingBoxGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionCurveGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureMotionTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructurePassSampleBufferAttachmentDescriptorArray); 
+_MTL_PRIVATE_DEF_CLS(MTLAccelerationStructureTriangleGeometryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLArchitecture); +_MTL_PRIVATE_DEF_CLS(MTLArgument); +_MTL_PRIVATE_DEF_CLS(MTLArgumentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLArrayType); +_MTL_PRIVATE_DEF_CLS(MTLAttribute); +_MTL_PRIVATE_DEF_CLS(MTLAttributeDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLAttributeDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLBinaryArchiveDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBlitPassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLBufferLayoutDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLBufferLayoutDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLCaptureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCaptureManager); +_MTL_PRIVATE_DEF_CLS(MTLCommandBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLCompileOptions); +_MTL_PRIVATE_DEF_CLS(MTLComputePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLComputePipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLComputePipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTLCounterSampleBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLDepthStencilDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLFunctionConstant); +_MTL_PRIVATE_DEF_CLS(MTLFunctionConstantValues); +_MTL_PRIVATE_DEF_CLS(MTLFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLFunctionReflection); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingAttributeAlwaysInline); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingFunctionNode); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingGraph); +_MTL_PRIVATE_DEF_CLS(MTLFunctionStitchingInputNode); +_MTL_PRIVATE_DEF_CLS(MTLHeapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIOCommandQueueDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIndirectCommandBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIndirectInstanceAccelerationStructureDescriptor); 
+_MTL_PRIVATE_DEF_CLS(MTLInstanceAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLIntersectionFunctionTableDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLLinkedFunctions); +_MTL_PRIVATE_DEF_CLS(MTLLogStateDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLLogicalToPhysicalColorAttachmentMap); +_MTL_PRIVATE_DEF_CLS(MTLMeshRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLMotionKeyframeData); +_MTL_PRIVATE_DEF_CLS(MTLPipelineBufferDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLPipelineBufferDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLPointerType); +_MTL_PRIVATE_DEF_CLS(MTLPrimitiveAccelerationStructureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateLayerArray); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateLayerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateMapDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRasterizationRateSampleArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassDepthAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassSampleBufferAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPassStencilAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineFunctionsDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLRenderPipelineReflection); +_MTL_PRIVATE_DEF_CLS(MTLResidencySetDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptorArray); 
+_MTL_PRIVATE_DEF_CLS(MTLResourceViewPoolDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLSamplerDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLSharedEventHandle); +_MTL_PRIVATE_DEF_CLS(MTLSharedEventListener); +_MTL_PRIVATE_DEF_CLS(MTLSharedTextureHandle); +_MTL_PRIVATE_DEF_CLS(MTLStageInputOutputDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStencilDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStitchedLibraryDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLStructMember); +_MTL_PRIVATE_DEF_CLS(MTLStructType); +_MTL_PRIVATE_DEF_CLS(MTLTensorDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTensorExtents); +_MTL_PRIVATE_DEF_CLS(MTLTensorReferenceType); +_MTL_PRIVATE_DEF_CLS(MTLTextureDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTextureReferenceType); +_MTL_PRIVATE_DEF_CLS(MTLTextureViewDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineColorAttachmentDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineColorAttachmentDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLTileRenderPipelineDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLType); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttribute); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttributeDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVertexAttributeDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLVertexBufferLayoutDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVertexBufferLayoutDescriptorArray); +_MTL_PRIVATE_DEF_CLS(MTLVertexDescriptor); +_MTL_PRIVATE_DEF_CLS(MTLVisibleFunctionTableDescriptor); + +} + +namespace MTL::Private::Protocol +{ + +_MTL_PRIVATE_DEF_PRO(MTL4Archive); +_MTL_PRIVATE_DEF_PRO(MTL4ArgumentTable); +_MTL_PRIVATE_DEF_PRO(MTL4BinaryFunction); +_MTL_PRIVATE_DEF_PRO(MTL4CommandAllocator); +_MTL_PRIVATE_DEF_PRO(MTL4CommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTL4CommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTL4CommandQueue); +_MTL_PRIVATE_DEF_PRO(MTL4CommitFeedback); +_MTL_PRIVATE_DEF_PRO(MTL4Compiler); +_MTL_PRIVATE_DEF_PRO(MTL4CompilerTask); +_MTL_PRIVATE_DEF_PRO(MTL4ComputeCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTL4CounterHeap); +_MTL_PRIVATE_DEF_PRO(MTL4MachineLearningCommandEncoder); 
+_MTL_PRIVATE_DEF_PRO(MTL4MachineLearningPipelineState); +_MTL_PRIVATE_DEF_PRO(MTL4PipelineDataSetSerializer); +_MTL_PRIVATE_DEF_PRO(MTL4RenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLAccelerationStructure); +_MTL_PRIVATE_DEF_PRO(MTLAccelerationStructureCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLAllocation); +_MTL_PRIVATE_DEF_PRO(MTLArgumentEncoder); +_MTL_PRIVATE_DEF_PRO(MTLBinaryArchive); +_MTL_PRIVATE_DEF_PRO(MTLBinding); +_MTL_PRIVATE_DEF_PRO(MTLBlitCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLBuffer); +_MTL_PRIVATE_DEF_PRO(MTLBufferBinding); +_MTL_PRIVATE_DEF_PRO(MTLCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLCommandBufferEncoderInfo); +_MTL_PRIVATE_DEF_PRO(MTLCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLCommandQueue); +_MTL_PRIVATE_DEF_PRO(MTLComputeCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLComputePipelineState); +_MTL_PRIVATE_DEF_PRO(MTLCounter); +_MTL_PRIVATE_DEF_PRO(MTLCounterSampleBuffer); +_MTL_PRIVATE_DEF_PRO(MTLCounterSet); +_MTL_PRIVATE_DEF_PRO(MTLDepthStencilState); +_MTL_PRIVATE_DEF_PRO(MTLDevice); +_MTL_PRIVATE_DEF_PRO(MTLDrawable); +_MTL_PRIVATE_DEF_PRO(MTLDynamicLibrary); +_MTL_PRIVATE_DEF_PRO(MTLEvent); +_MTL_PRIVATE_DEF_PRO(MTLFence); +_MTL_PRIVATE_DEF_PRO(MTLFunction); +_MTL_PRIVATE_DEF_PRO(MTLFunctionHandle); +_MTL_PRIVATE_DEF_PRO(MTLFunctionLog); +_MTL_PRIVATE_DEF_PRO(MTLFunctionLogDebugLocation); +_MTL_PRIVATE_DEF_PRO(MTLFunctionStitchingAttribute); +_MTL_PRIVATE_DEF_PRO(MTLFunctionStitchingNode); +_MTL_PRIVATE_DEF_PRO(MTLHeap); +_MTL_PRIVATE_DEF_PRO(MTLIOCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIOCommandQueue); +_MTL_PRIVATE_DEF_PRO(MTLIOFileHandle); +_MTL_PRIVATE_DEF_PRO(MTLIOScratchBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIOScratchBufferAllocator); +_MTL_PRIVATE_DEF_PRO(MTLIndirectCommandBuffer); +_MTL_PRIVATE_DEF_PRO(MTLIndirectComputeCommand); +_MTL_PRIVATE_DEF_PRO(MTLIndirectRenderCommand); +_MTL_PRIVATE_DEF_PRO(MTLIntersectionFunctionTable); +_MTL_PRIVATE_DEF_PRO(MTLLibrary); +_MTL_PRIVATE_DEF_PRO(MTLLogContainer); 
+_MTL_PRIVATE_DEF_PRO(MTLLogState); +_MTL_PRIVATE_DEF_PRO(MTLObjectPayloadBinding); +_MTL_PRIVATE_DEF_PRO(MTLParallelRenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLRasterizationRateMap); +_MTL_PRIVATE_DEF_PRO(MTLRenderCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLRenderPipelineState); +_MTL_PRIVATE_DEF_PRO(MTLResidencySet); +_MTL_PRIVATE_DEF_PRO(MTLResource); +_MTL_PRIVATE_DEF_PRO(MTLResourceStateCommandEncoder); +_MTL_PRIVATE_DEF_PRO(MTLResourceViewPool); +_MTL_PRIVATE_DEF_PRO(MTLSamplerState); +_MTL_PRIVATE_DEF_PRO(MTLSharedEvent); +_MTL_PRIVATE_DEF_PRO(MTLTensor); +_MTL_PRIVATE_DEF_PRO(MTLTensorBinding); +_MTL_PRIVATE_DEF_PRO(MTLTexture); +_MTL_PRIVATE_DEF_PRO(MTLTextureBinding); +_MTL_PRIVATE_DEF_PRO(MTLTextureViewPool); +_MTL_PRIVATE_DEF_PRO(MTLThreadgroupBinding); +_MTL_PRIVATE_DEF_PRO(MTLVisibleFunctionTable); + +} + +namespace MTL::Private::Selector +{ + +_MTL_PRIVATE_DEF_SEL(GPUEndTime, + "GPUEndTime"); +_MTL_PRIVATE_DEF_SEL(GPUStartTime, + "GPUStartTime"); +_MTL_PRIVATE_DEF_SEL(URL, + "URL"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureCommandEncoder, + "accelerationStructureCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureCommandEncoderWithDescriptor_, + "accelerationStructureCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(accelerationStructurePassDescriptor, + "accelerationStructurePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(accelerationStructureSizesWithDescriptor_, + "accelerationStructureSizesWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(access, + "access"); +_MTL_PRIVATE_DEF_SEL(addAllocation_, + "addAllocation:"); +_MTL_PRIVATE_DEF_SEL(addAllocations_count_, + "addAllocations:count:"); +_MTL_PRIVATE_DEF_SEL(addBarrier, + "addBarrier"); +_MTL_PRIVATE_DEF_SEL(addCompletedHandler_, + "addCompletedHandler:"); +_MTL_PRIVATE_DEF_SEL(addComputePipelineFunctionsWithDescriptor_error_, + "addComputePipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addDebugMarker_range_, + "addDebugMarker:range:"); 
+_MTL_PRIVATE_DEF_SEL(addFeedbackHandler_, + "addFeedbackHandler:"); +_MTL_PRIVATE_DEF_SEL(addFunctionWithDescriptor_library_error_, + "addFunctionWithDescriptor:library:error:"); +_MTL_PRIVATE_DEF_SEL(addLibraryWithDescriptor_error_, + "addLibraryWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addLogHandler_, + "addLogHandler:"); +_MTL_PRIVATE_DEF_SEL(addMeshRenderPipelineFunctionsWithDescriptor_error_, + "addMeshRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addPresentedHandler_, + "addPresentedHandler:"); +_MTL_PRIVATE_DEF_SEL(addRenderPipelineFunctionsWithDescriptor_error_, + "addRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(addResidencySet_, + "addResidencySet:"); +_MTL_PRIVATE_DEF_SEL(addResidencySets_count_, + "addResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(addScheduledHandler_, + "addScheduledHandler:"); +_MTL_PRIVATE_DEF_SEL(addTileRenderPipelineFunctionsWithDescriptor_error_, + "addTileRenderPipelineFunctionsWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(alignment, + "alignment"); +_MTL_PRIVATE_DEF_SEL(allAllocations, + "allAllocations"); +_MTL_PRIVATE_DEF_SEL(allocatedSize, + "allocatedSize"); +_MTL_PRIVATE_DEF_SEL(allocationCount, + "allocationCount"); +_MTL_PRIVATE_DEF_SEL(allowDuplicateIntersectionFunctionInvocation, + "allowDuplicateIntersectionFunctionInvocation"); +_MTL_PRIVATE_DEF_SEL(allowGPUOptimizedContents, + "allowGPUOptimizedContents"); +_MTL_PRIVATE_DEF_SEL(allowReferencingUndefinedSymbols, + "allowReferencingUndefinedSymbols"); +_MTL_PRIVATE_DEF_SEL(alphaBlendOperation, + "alphaBlendOperation"); +_MTL_PRIVATE_DEF_SEL(alphaToCoverageState, + "alphaToCoverageState"); +_MTL_PRIVATE_DEF_SEL(alphaToOneState, + "alphaToOneState"); +_MTL_PRIVATE_DEF_SEL(architecture, + "architecture"); +_MTL_PRIVATE_DEF_SEL(areBarycentricCoordsSupported, + "areBarycentricCoordsSupported"); +_MTL_PRIVATE_DEF_SEL(areProgrammableSamplePositionsSupported, + "areProgrammableSamplePositionsSupported"); 
+_MTL_PRIVATE_DEF_SEL(areRasterOrderGroupsSupported, + "areRasterOrderGroupsSupported"); +_MTL_PRIVATE_DEF_SEL(argumentBuffersSupport, + "argumentBuffersSupport"); +_MTL_PRIVATE_DEF_SEL(argumentDescriptor, + "argumentDescriptor"); +_MTL_PRIVATE_DEF_SEL(argumentIndex, + "argumentIndex"); +_MTL_PRIVATE_DEF_SEL(argumentIndexStride, + "argumentIndexStride"); +_MTL_PRIVATE_DEF_SEL(arguments, + "arguments"); +_MTL_PRIVATE_DEF_SEL(arrayLength, + "arrayLength"); +_MTL_PRIVATE_DEF_SEL(arrayType, + "arrayType"); +_MTL_PRIVATE_DEF_SEL(attributeIndex, + "attributeIndex"); +_MTL_PRIVATE_DEF_SEL(attributeType, + "attributeType"); +_MTL_PRIVATE_DEF_SEL(attributes, + "attributes"); +_MTL_PRIVATE_DEF_SEL(backFaceStencil, + "backFaceStencil"); +_MTL_PRIVATE_DEF_SEL(barrierAfterEncoderStages_beforeEncoderStages_visibilityOptions_, + "barrierAfterEncoderStages:beforeEncoderStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterQueueStages_beforeStages_, + "barrierAfterQueueStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterQueueStages_beforeStages_visibilityOptions_, + "barrierAfterQueueStages:beforeStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(barrierAfterStages_beforeQueueStages_visibilityOptions_, + "barrierAfterStages:beforeQueueStages:visibilityOptions:"); +_MTL_PRIVATE_DEF_SEL(baseResourceID, + "baseResourceID"); +_MTL_PRIVATE_DEF_SEL(beginCommandBufferWithAllocator_, + "beginCommandBufferWithAllocator:"); +_MTL_PRIVATE_DEF_SEL(beginCommandBufferWithAllocator_options_, + "beginCommandBufferWithAllocator:options:"); +_MTL_PRIVATE_DEF_SEL(binaryArchives, + "binaryArchives"); +_MTL_PRIVATE_DEF_SEL(binaryFunctions, + "binaryFunctions"); +_MTL_PRIVATE_DEF_SEL(binaryLinkedFunctions, + "binaryLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(bindings, + "bindings"); +_MTL_PRIVATE_DEF_SEL(blendingState, + "blendingState"); +_MTL_PRIVATE_DEF_SEL(blitCommandEncoder, + "blitCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(blitCommandEncoderWithDescriptor_, + 
"blitCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(blitPassDescriptor, + "blitPassDescriptor"); +_MTL_PRIVATE_DEF_SEL(borderColor, + "borderColor"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBuffer, + "boundingBoxBuffer"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBufferOffset, + "boundingBoxBufferOffset"); +_MTL_PRIVATE_DEF_SEL(boundingBoxBuffers, + "boundingBoxBuffers"); +_MTL_PRIVATE_DEF_SEL(boundingBoxCount, + "boundingBoxCount"); +_MTL_PRIVATE_DEF_SEL(boundingBoxStride, + "boundingBoxStride"); +_MTL_PRIVATE_DEF_SEL(buffer, + "buffer"); +_MTL_PRIVATE_DEF_SEL(bufferAlignment, + "bufferAlignment"); +_MTL_PRIVATE_DEF_SEL(bufferBytesPerRow, + "bufferBytesPerRow"); +_MTL_PRIVATE_DEF_SEL(bufferDataSize, + "bufferDataSize"); +_MTL_PRIVATE_DEF_SEL(bufferDataType, + "bufferDataType"); +_MTL_PRIVATE_DEF_SEL(bufferIndex, + "bufferIndex"); +_MTL_PRIVATE_DEF_SEL(bufferOffset, + "bufferOffset"); +_MTL_PRIVATE_DEF_SEL(bufferPointerType, + "bufferPointerType"); +_MTL_PRIVATE_DEF_SEL(bufferSize, + "bufferSize"); +_MTL_PRIVATE_DEF_SEL(bufferStructType, + "bufferStructType"); +_MTL_PRIVATE_DEF_SEL(buffers, + "buffers"); +_MTL_PRIVATE_DEF_SEL(buildAccelerationStructure_descriptor_scratchBuffer_, + "buildAccelerationStructure:descriptor:scratchBuffer:"); +_MTL_PRIVATE_DEF_SEL(buildAccelerationStructure_descriptor_scratchBuffer_scratchBufferOffset_, + "buildAccelerationStructure:descriptor:scratchBuffer:scratchBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(captureObject, + "captureObject"); +_MTL_PRIVATE_DEF_SEL(clearBarrier, + "clearBarrier"); +_MTL_PRIVATE_DEF_SEL(clearColor, + "clearColor"); +_MTL_PRIVATE_DEF_SEL(clearDepth, + "clearDepth"); +_MTL_PRIVATE_DEF_SEL(clearStencil, + "clearStencil"); +_MTL_PRIVATE_DEF_SEL(colorAttachmentMappingState, + "colorAttachmentMappingState"); +_MTL_PRIVATE_DEF_SEL(colorAttachments, + "colorAttachments"); +_MTL_PRIVATE_DEF_SEL(column, + "column"); +_MTL_PRIVATE_DEF_SEL(commandBuffer, + "commandBuffer"); +_MTL_PRIVATE_DEF_SEL(commandBufferWithDescriptor_, + 
"commandBufferWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(commandBufferWithUnretainedReferences, + "commandBufferWithUnretainedReferences"); +_MTL_PRIVATE_DEF_SEL(commandQueue, + "commandQueue"); +_MTL_PRIVATE_DEF_SEL(commandTypes, + "commandTypes"); +_MTL_PRIVATE_DEF_SEL(commit, + "commit"); +_MTL_PRIVATE_DEF_SEL(commit_count_, + "commit:count:"); +_MTL_PRIVATE_DEF_SEL(commit_count_options_, + "commit:count:options:"); +_MTL_PRIVATE_DEF_SEL(compareFunction, + "compareFunction"); +_MTL_PRIVATE_DEF_SEL(compileSymbolVisibility, + "compileSymbolVisibility"); +_MTL_PRIVATE_DEF_SEL(compiler, + "compiler"); +_MTL_PRIVATE_DEF_SEL(compressionType, + "compressionType"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoder, + "computeCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoderWithDescriptor_, + "computeCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(computeCommandEncoderWithDispatchType_, + "computeCommandEncoderWithDispatchType:"); +_MTL_PRIVATE_DEF_SEL(computeFunction, + "computeFunction"); +_MTL_PRIVATE_DEF_SEL(computeFunctionDescriptor, + "computeFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(computePassDescriptor, + "computePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(concurrentDispatchThreadgroups_threadsPerThreadgroup_, + "concurrentDispatchThreadgroups:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(concurrentDispatchThreads_threadsPerThreadgroup_, + "concurrentDispatchThreads:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(configuration, + "configuration"); +_MTL_PRIVATE_DEF_SEL(constantBlockAlignment, + "constantBlockAlignment"); +_MTL_PRIVATE_DEF_SEL(constantDataAtIndex_, + "constantDataAtIndex:"); +_MTL_PRIVATE_DEF_SEL(constantValues, + "constantValues"); +_MTL_PRIVATE_DEF_SEL(containsAllocation_, + "containsAllocation:"); +_MTL_PRIVATE_DEF_SEL(contents, + "contents"); +_MTL_PRIVATE_DEF_SEL(controlDependencies, + "controlDependencies"); +_MTL_PRIVATE_DEF_SEL(controlPointBuffer, + "controlPointBuffer"); +_MTL_PRIVATE_DEF_SEL(controlPointBufferOffset, + 
"controlPointBufferOffset"); +_MTL_PRIVATE_DEF_SEL(controlPointBuffers, + "controlPointBuffers"); +_MTL_PRIVATE_DEF_SEL(controlPointCount, + "controlPointCount"); +_MTL_PRIVATE_DEF_SEL(controlPointFormat, + "controlPointFormat"); +_MTL_PRIVATE_DEF_SEL(controlPointStride, + "controlPointStride"); +_MTL_PRIVATE_DEF_SEL(convertSparsePixelRegions_toTileRegions_withTileSize_alignmentMode_numRegions_, + "convertSparsePixelRegions:toTileRegions:withTileSize:alignmentMode:numRegions:"); +_MTL_PRIVATE_DEF_SEL(convertSparseTileRegions_toPixelRegions_withTileSize_numRegions_, + "convertSparseTileRegions:toPixelRegions:withTileSize:numRegions:"); +_MTL_PRIVATE_DEF_SEL(copyAccelerationStructure_toAccelerationStructure_, + "copyAccelerationStructure:toAccelerationStructure:"); +_MTL_PRIVATE_DEF_SEL(copyAndCompactAccelerationStructure_toAccelerationStructure_, + "copyAndCompactAccelerationStructure:toAccelerationStructure:"); +_MTL_PRIVATE_DEF_SEL(copyBufferMappingsFromBuffer_toBuffer_operations_count_, + "copyBufferMappingsFromBuffer:toBuffer:operations:count:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "copyFromBuffer:sourceOffset:sourceBytesPerRow:sourceBytesPerImage:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_sourceBytesPerRow_sourceBytesPerImage_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_options_, + "copyFromBuffer:sourceOffset:sourceBytesPerRow:sourceBytesPerImage:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:options:"); +_MTL_PRIVATE_DEF_SEL(copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_, + "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:"); +_MTL_PRIVATE_DEF_SEL(copyFromTensor_sourceOrigin_sourceDimensions_toTensor_destinationOrigin_destinationDimensions_, + 
"copyFromTensor:sourceOrigin:sourceDimensions:toTensor:destinationOrigin:destinationDimensions:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toBuffer:destinationOffset:destinationBytesPerRow:destinationBytesPerImage:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toBuffer_destinationOffset_destinationBytesPerRow_destinationBytesPerImage_options_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toBuffer:destinationOffset:destinationBytesPerRow:destinationBytesPerImage:options:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "copyFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_sourceSlice_sourceLevel_toTexture_destinationSlice_destinationLevel_sliceCount_levelCount_, + "copyFromTexture:sourceSlice:sourceLevel:toTexture:destinationSlice:destinationLevel:sliceCount:levelCount:"); +_MTL_PRIVATE_DEF_SEL(copyFromTexture_toTexture_, + "copyFromTexture:toTexture:"); +_MTL_PRIVATE_DEF_SEL(copyIndirectCommandBuffer_sourceRange_destination_destinationIndex_, + "copyIndirectCommandBuffer:sourceRange:destination:destinationIndex:"); +_MTL_PRIVATE_DEF_SEL(copyParameterDataToBuffer_offset_, + "copyParameterDataToBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(copyResourceViewsFromPool_sourceRange_destinationIndex_, + "copyResourceViewsFromPool:sourceRange:destinationIndex:"); +_MTL_PRIVATE_DEF_SEL(copyStatusToBuffer_offset_, + "copyStatusToBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(copyTextureMappingsFromTexture_toTexture_operations_count_, + "copyTextureMappingsFromTexture:toTexture:operations:count:"); 
+_MTL_PRIVATE_DEF_SEL(count, + "count"); +_MTL_PRIVATE_DEF_SEL(counterSet, + "counterSet"); +_MTL_PRIVATE_DEF_SEL(counterSets, + "counterSets"); +_MTL_PRIVATE_DEF_SEL(counters, + "counters"); +_MTL_PRIVATE_DEF_SEL(cpuCacheMode, + "cpuCacheMode"); +_MTL_PRIVATE_DEF_SEL(currentAllocatedSize, + "currentAllocatedSize"); +_MTL_PRIVATE_DEF_SEL(curveBasis, + "curveBasis"); +_MTL_PRIVATE_DEF_SEL(curveEndCaps, + "curveEndCaps"); +_MTL_PRIVATE_DEF_SEL(curveType, + "curveType"); +_MTL_PRIVATE_DEF_SEL(data, + "data"); +_MTL_PRIVATE_DEF_SEL(dataSize, + "dataSize"); +_MTL_PRIVATE_DEF_SEL(dataType, + "dataType"); +_MTL_PRIVATE_DEF_SEL(dealloc, + "dealloc"); +_MTL_PRIVATE_DEF_SEL(debugLocation, + "debugLocation"); +_MTL_PRIVATE_DEF_SEL(debugSignposts, + "debugSignposts"); +_MTL_PRIVATE_DEF_SEL(defaultCaptureScope, + "defaultCaptureScope"); +_MTL_PRIVATE_DEF_SEL(defaultRasterSampleCount, + "defaultRasterSampleCount"); +_MTL_PRIVATE_DEF_SEL(depth, + "depth"); +_MTL_PRIVATE_DEF_SEL(depthAttachment, + "depthAttachment"); +_MTL_PRIVATE_DEF_SEL(depthAttachmentPixelFormat, + "depthAttachmentPixelFormat"); +_MTL_PRIVATE_DEF_SEL(depthCompareFunction, + "depthCompareFunction"); +_MTL_PRIVATE_DEF_SEL(depthFailureOperation, + "depthFailureOperation"); +_MTL_PRIVATE_DEF_SEL(depthPlane, + "depthPlane"); +_MTL_PRIVATE_DEF_SEL(depthResolveFilter, + "depthResolveFilter"); +_MTL_PRIVATE_DEF_SEL(depthStencilPassOperation, + "depthStencilPassOperation"); +_MTL_PRIVATE_DEF_SEL(descriptor, + "descriptor"); +_MTL_PRIVATE_DEF_SEL(destination, + "destination"); +_MTL_PRIVATE_DEF_SEL(destinationAlphaBlendFactor, + "destinationAlphaBlendFactor"); +_MTL_PRIVATE_DEF_SEL(destinationRGBBlendFactor, + "destinationRGBBlendFactor"); +_MTL_PRIVATE_DEF_SEL(device, + "device"); +_MTL_PRIVATE_DEF_SEL(didModifyRange_, + "didModifyRange:"); +_MTL_PRIVATE_DEF_SEL(dimensions, + "dimensions"); +_MTL_PRIVATE_DEF_SEL(dispatchNetworkWithIntermediatesHeap_, + "dispatchNetworkWithIntermediatesHeap:"); 
+_MTL_PRIVATE_DEF_SEL(dispatchQueue, + "dispatchQueue"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroups_threadsPerThreadgroup_, + "dispatchThreadgroups:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerThreadgroup_, + "dispatchThreadgroupsWithIndirectBuffer:indirectBufferOffset:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadgroupsWithIndirectBuffer_threadsPerThreadgroup_, + "dispatchThreadgroupsWithIndirectBuffer:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreads_threadsPerThreadgroup_, + "dispatchThreads:threadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadsPerTile_, + "dispatchThreadsPerTile:"); +_MTL_PRIVATE_DEF_SEL(dispatchThreadsWithIndirectBuffer_, + "dispatchThreadsWithIndirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(dispatchType, + "dispatchType"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawIndexedPatches:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_, + "drawIndexedPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_, + 
"drawIndexedPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:controlPointIndexBuffer:controlPointIndexBufferOffset:instanceCount:baseInstance:tessellationFactorBuffer:tessellationFactorBufferOffset:tessellationFactorBufferInstanceStride:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferLength_instanceCount_baseVertex_baseInstance_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferLength:instanceCount:baseVertex:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_, + "drawIndexedPrimitives:indexCount:indexType:indexBuffer:indexBufferOffset:instanceCount:baseVertex:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferLength_indirectBuffer_, + "drawIndexedPrimitives:indexType:indexBuffer:indexBufferLength:indirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawIndexedPrimitives:indexType:indexBuffer:indexBufferOffset:indirectBuffer:indirectBufferOffset:"); 
+_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroups:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroupsWithIndirectBuffer:indirectBufferOffset:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreadgroupsWithIndirectBuffer_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreadgroupsWithIndirectBuffer:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_, + "drawMeshThreads:threadsPerObjectThreadgroup:threadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchIndexBuffer_patchIndexBufferOffset_indirectBuffer_indirectBufferOffset_, + "drawPatches:patchIndexBuffer:patchIndexBufferOffset:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_, + "drawPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_, + "drawPatches:patchStart:patchCount:patchIndexBuffer:patchIndexBufferOffset:instanceCount:baseInstance:tessellationFactorBuffer:tessellationFactorBufferOffset:tessellationFactorBufferInstanceStride:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_indirectBuffer_, + "drawPrimitives:indirectBuffer:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_indirectBuffer_indirectBufferOffset_, + "drawPrimitives:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_, + 
"drawPrimitives:vertexStart:vertexCount:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_, + "drawPrimitives:vertexStart:vertexCount:instanceCount:"); +_MTL_PRIVATE_DEF_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_, + "drawPrimitives:vertexStart:vertexCount:instanceCount:baseInstance:"); +_MTL_PRIVATE_DEF_SEL(drawableID, + "drawableID"); +_MTL_PRIVATE_DEF_SEL(elementArrayType, + "elementArrayType"); +_MTL_PRIVATE_DEF_SEL(elementIsArgumentBuffer, + "elementIsArgumentBuffer"); +_MTL_PRIVATE_DEF_SEL(elementPointerType, + "elementPointerType"); +_MTL_PRIVATE_DEF_SEL(elementStructType, + "elementStructType"); +_MTL_PRIVATE_DEF_SEL(elementTensorReferenceType, + "elementTensorReferenceType"); +_MTL_PRIVATE_DEF_SEL(elementTextureReferenceType, + "elementTextureReferenceType"); +_MTL_PRIVATE_DEF_SEL(elementType, + "elementType"); +_MTL_PRIVATE_DEF_SEL(enableLogging, + "enableLogging"); +_MTL_PRIVATE_DEF_SEL(encodeSignalEvent_value_, + "encodeSignalEvent:value:"); +_MTL_PRIVATE_DEF_SEL(encodeWaitForEvent_value_, + "encodeWaitForEvent:value:"); +_MTL_PRIVATE_DEF_SEL(encodedLength, + "encodedLength"); +_MTL_PRIVATE_DEF_SEL(encoderLabel, + "encoderLabel"); +_MTL_PRIVATE_DEF_SEL(endCommandBuffer, + "endCommandBuffer"); +_MTL_PRIVATE_DEF_SEL(endEncoding, + "endEncoding"); +_MTL_PRIVATE_DEF_SEL(endOfEncoderSampleIndex, + "endOfEncoderSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endOfFragmentSampleIndex, + "endOfFragmentSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endOfVertexSampleIndex, + "endOfVertexSampleIndex"); +_MTL_PRIVATE_DEF_SEL(endResidency, + "endResidency"); +_MTL_PRIVATE_DEF_SEL(enqueue, + "enqueue"); +_MTL_PRIVATE_DEF_SEL(enqueueBarrier, + "enqueueBarrier"); +_MTL_PRIVATE_DEF_SEL(error, + "error"); +_MTL_PRIVATE_DEF_SEL(errorOptions, + "errorOptions"); +_MTL_PRIVATE_DEF_SEL(errorState, + "errorState"); +_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_indirectBuffer_, + "executeCommandsInBuffer:indirectBuffer:"); 
+_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_, + "executeCommandsInBuffer:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(executeCommandsInBuffer_withRange_, + "executeCommandsInBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(extentAtDimensionIndex_, + "extentAtDimensionIndex:"); +_MTL_PRIVATE_DEF_SEL(fastMathEnabled, + "fastMathEnabled"); +_MTL_PRIVATE_DEF_SEL(feedbackQueue, + "feedbackQueue"); +_MTL_PRIVATE_DEF_SEL(fillBuffer_range_value_, + "fillBuffer:range:value:"); +_MTL_PRIVATE_DEF_SEL(firstMipmapInTail, + "firstMipmapInTail"); +_MTL_PRIVATE_DEF_SEL(format, + "format"); +_MTL_PRIVATE_DEF_SEL(fragmentAdditionalBinaryFunctions, + "fragmentAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(fragmentArguments, + "fragmentArguments"); +_MTL_PRIVATE_DEF_SEL(fragmentBindings, + "fragmentBindings"); +_MTL_PRIVATE_DEF_SEL(fragmentBuffers, + "fragmentBuffers"); +_MTL_PRIVATE_DEF_SEL(fragmentFunction, + "fragmentFunction"); +_MTL_PRIVATE_DEF_SEL(fragmentFunctionDescriptor, + "fragmentFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(fragmentLinkedFunctions, + "fragmentLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(fragmentLinkingDescriptor, + "fragmentLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(fragmentPreloadedLibraries, + "fragmentPreloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(fragmentStaticLinkingDescriptor, + "fragmentStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(frontFaceStencil, + "frontFaceStencil"); +_MTL_PRIVATE_DEF_SEL(function, + "function"); +_MTL_PRIVATE_DEF_SEL(functionConstantsDictionary, + "functionConstantsDictionary"); +_MTL_PRIVATE_DEF_SEL(functionCount, + "functionCount"); +_MTL_PRIVATE_DEF_SEL(functionDescriptor, + "functionDescriptor"); +_MTL_PRIVATE_DEF_SEL(functionDescriptors, + "functionDescriptors"); +_MTL_PRIVATE_DEF_SEL(functionGraph, + "functionGraph"); +_MTL_PRIVATE_DEF_SEL(functionGraphs, + "functionGraphs"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithBinaryFunction_, + "functionHandleWithBinaryFunction:"); 
+_MTL_PRIVATE_DEF_SEL(functionHandleWithBinaryFunction_stage_, + "functionHandleWithBinaryFunction:stage:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithFunction_, + "functionHandleWithFunction:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithFunction_stage_, + "functionHandleWithFunction:stage:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithName_, + "functionHandleWithName:"); +_MTL_PRIVATE_DEF_SEL(functionHandleWithName_stage_, + "functionHandleWithName:stage:"); +_MTL_PRIVATE_DEF_SEL(functionName, + "functionName"); +_MTL_PRIVATE_DEF_SEL(functionNames, + "functionNames"); +_MTL_PRIVATE_DEF_SEL(functionType, + "functionType"); +_MTL_PRIVATE_DEF_SEL(functions, + "functions"); +_MTL_PRIVATE_DEF_SEL(generateMipmapsForTexture_, + "generateMipmapsForTexture:"); +_MTL_PRIVATE_DEF_SEL(geometryDescriptors, + "geometryDescriptors"); +_MTL_PRIVATE_DEF_SEL(getBytes_bytesPerRow_bytesPerImage_fromRegion_mipmapLevel_slice_, + "getBytes:bytesPerRow:bytesPerImage:fromRegion:mipmapLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(getBytes_bytesPerRow_fromRegion_mipmapLevel_, + "getBytes:bytesPerRow:fromRegion:mipmapLevel:"); +_MTL_PRIVATE_DEF_SEL(getBytes_strides_fromSliceOrigin_sliceDimensions_, + "getBytes:strides:fromSliceOrigin:sliceDimensions:"); +_MTL_PRIVATE_DEF_SEL(getDefaultSamplePositions_count_, + "getDefaultSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(getPhysicalIndexForLogicalIndex_, + "getPhysicalIndexForLogicalIndex:"); +_MTL_PRIVATE_DEF_SEL(getSamplePositions_count_, + "getSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(getTextureAccessCounters_region_mipLevel_slice_resetCounters_countersBuffer_countersBufferOffset_, + "getTextureAccessCounters:region:mipLevel:slice:resetCounters:countersBuffer:countersBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(gpuAddress, + "gpuAddress"); +_MTL_PRIVATE_DEF_SEL(gpuResourceID, + "gpuResourceID"); +_MTL_PRIVATE_DEF_SEL(groups, + "groups"); +_MTL_PRIVATE_DEF_SEL(hasUnifiedMemory, + "hasUnifiedMemory"); +_MTL_PRIVATE_DEF_SEL(hazardTrackingMode, + 
"hazardTrackingMode"); +_MTL_PRIVATE_DEF_SEL(heap, + "heap"); +_MTL_PRIVATE_DEF_SEL(heapAccelerationStructureSizeAndAlignWithDescriptor_, + "heapAccelerationStructureSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(heapAccelerationStructureSizeAndAlignWithSize_, + "heapAccelerationStructureSizeAndAlignWithSize:"); +_MTL_PRIVATE_DEF_SEL(heapBufferSizeAndAlignWithLength_options_, + "heapBufferSizeAndAlignWithLength:options:"); +_MTL_PRIVATE_DEF_SEL(heapOffset, + "heapOffset"); +_MTL_PRIVATE_DEF_SEL(heapTextureSizeAndAlignWithDescriptor_, + "heapTextureSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(height, + "height"); +_MTL_PRIVATE_DEF_SEL(horizontal, + "horizontal"); +_MTL_PRIVATE_DEF_SEL(horizontalSampleStorage, + "horizontalSampleStorage"); +_MTL_PRIVATE_DEF_SEL(imageblockMemoryLengthForDimensions_, + "imageblockMemoryLengthForDimensions:"); +_MTL_PRIVATE_DEF_SEL(imageblockSampleLength, + "imageblockSampleLength"); +_MTL_PRIVATE_DEF_SEL(index, + "index"); +_MTL_PRIVATE_DEF_SEL(indexBuffer, + "indexBuffer"); +_MTL_PRIVATE_DEF_SEL(indexBufferIndex, + "indexBufferIndex"); +_MTL_PRIVATE_DEF_SEL(indexBufferOffset, + "indexBufferOffset"); +_MTL_PRIVATE_DEF_SEL(indexType, + "indexType"); +_MTL_PRIVATE_DEF_SEL(indirectComputeCommandAtIndex_, + "indirectComputeCommandAtIndex:"); +_MTL_PRIVATE_DEF_SEL(indirectRenderCommandAtIndex_, + "indirectRenderCommandAtIndex:"); +_MTL_PRIVATE_DEF_SEL(inheritBuffers, + "inheritBuffers"); +_MTL_PRIVATE_DEF_SEL(inheritCullMode, + "inheritCullMode"); +_MTL_PRIVATE_DEF_SEL(inheritDepthBias, + "inheritDepthBias"); +_MTL_PRIVATE_DEF_SEL(inheritDepthClipMode, + "inheritDepthClipMode"); +_MTL_PRIVATE_DEF_SEL(inheritDepthStencilState, + "inheritDepthStencilState"); +_MTL_PRIVATE_DEF_SEL(inheritFrontFacingWinding, + "inheritFrontFacingWinding"); +_MTL_PRIVATE_DEF_SEL(inheritPipelineState, + "inheritPipelineState"); +_MTL_PRIVATE_DEF_SEL(inheritTriangleFillMode, + "inheritTriangleFillMode"); +_MTL_PRIVATE_DEF_SEL(init, + "init"); 
+_MTL_PRIVATE_DEF_SEL(initWithArgumentIndex_, + "initWithArgumentIndex:"); +_MTL_PRIVATE_DEF_SEL(initWithDispatchQueue_, + "initWithDispatchQueue:"); +_MTL_PRIVATE_DEF_SEL(initWithFunctionName_nodes_outputNode_attributes_, + "initWithFunctionName:nodes:outputNode:attributes:"); +_MTL_PRIVATE_DEF_SEL(initWithName_arguments_controlDependencies_, + "initWithName:arguments:controlDependencies:"); +_MTL_PRIVATE_DEF_SEL(initWithRank_values_, + "initWithRank:values:"); +_MTL_PRIVATE_DEF_SEL(initWithSampleCount_, + "initWithSampleCount:"); +_MTL_PRIVATE_DEF_SEL(initWithSampleCount_horizontal_vertical_, + "initWithSampleCount:horizontal:vertical:"); +_MTL_PRIVATE_DEF_SEL(initialCapacity, + "initialCapacity"); +_MTL_PRIVATE_DEF_SEL(initializeBindings, + "initializeBindings"); +_MTL_PRIVATE_DEF_SEL(inputDimensionsAtBufferIndex_, + "inputDimensionsAtBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(inputPrimitiveTopology, + "inputPrimitiveTopology"); +_MTL_PRIVATE_DEF_SEL(insertDebugCaptureBoundary, + "insertDebugCaptureBoundary"); +_MTL_PRIVATE_DEF_SEL(insertDebugSignpost_, + "insertDebugSignpost:"); +_MTL_PRIVATE_DEF_SEL(insertLibraries, + "insertLibraries"); +_MTL_PRIVATE_DEF_SEL(installName, + "installName"); +_MTL_PRIVATE_DEF_SEL(instanceCount, + "instanceCount"); +_MTL_PRIVATE_DEF_SEL(instanceCountBuffer, + "instanceCountBuffer"); +_MTL_PRIVATE_DEF_SEL(instanceCountBufferOffset, + "instanceCountBufferOffset"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorBuffer, + "instanceDescriptorBuffer"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorBufferOffset, + "instanceDescriptorBufferOffset"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorStride, + "instanceDescriptorStride"); +_MTL_PRIVATE_DEF_SEL(instanceDescriptorType, + "instanceDescriptorType"); +_MTL_PRIVATE_DEF_SEL(instanceTransformationMatrixLayout, + "instanceTransformationMatrixLayout"); +_MTL_PRIVATE_DEF_SEL(instancedAccelerationStructures, + "instancedAccelerationStructures"); +_MTL_PRIVATE_DEF_SEL(intermediatesHeapSize, + 
"intermediatesHeapSize"); +_MTL_PRIVATE_DEF_SEL(intersectionFunctionTableDescriptor, + "intersectionFunctionTableDescriptor"); +_MTL_PRIVATE_DEF_SEL(intersectionFunctionTableOffset, + "intersectionFunctionTableOffset"); +_MTL_PRIVATE_DEF_SEL(invalidateCounterRange_, + "invalidateCounterRange:"); +_MTL_PRIVATE_DEF_SEL(iosurface, + "iosurface"); +_MTL_PRIVATE_DEF_SEL(iosurfacePlane, + "iosurfacePlane"); +_MTL_PRIVATE_DEF_SEL(isActive, + "isActive"); +_MTL_PRIVATE_DEF_SEL(isAliasable, + "isAliasable"); +_MTL_PRIVATE_DEF_SEL(isAlphaToCoverageEnabled, + "isAlphaToCoverageEnabled"); +_MTL_PRIVATE_DEF_SEL(isAlphaToOneEnabled, + "isAlphaToOneEnabled"); +_MTL_PRIVATE_DEF_SEL(isArgument, + "isArgument"); +_MTL_PRIVATE_DEF_SEL(isBlendingEnabled, + "isBlendingEnabled"); +_MTL_PRIVATE_DEF_SEL(isCapturing, + "isCapturing"); +_MTL_PRIVATE_DEF_SEL(isDepth24Stencil8PixelFormatSupported, + "isDepth24Stencil8PixelFormatSupported"); +_MTL_PRIVATE_DEF_SEL(isDepthTexture, + "isDepthTexture"); +_MTL_PRIVATE_DEF_SEL(isDepthWriteEnabled, + "isDepthWriteEnabled"); +_MTL_PRIVATE_DEF_SEL(isFramebufferOnly, + "isFramebufferOnly"); +_MTL_PRIVATE_DEF_SEL(isHeadless, + "isHeadless"); +_MTL_PRIVATE_DEF_SEL(isLowPower, + "isLowPower"); +_MTL_PRIVATE_DEF_SEL(isPatchControlPointData, + "isPatchControlPointData"); +_MTL_PRIVATE_DEF_SEL(isPatchData, + "isPatchData"); +_MTL_PRIVATE_DEF_SEL(isRasterizationEnabled, + "isRasterizationEnabled"); +_MTL_PRIVATE_DEF_SEL(isRemovable, + "isRemovable"); +_MTL_PRIVATE_DEF_SEL(isShareable, + "isShareable"); +_MTL_PRIVATE_DEF_SEL(isSparse, + "isSparse"); +_MTL_PRIVATE_DEF_SEL(isTessellationFactorScaleEnabled, + "isTessellationFactorScaleEnabled"); +_MTL_PRIVATE_DEF_SEL(isUsed, + "isUsed"); +_MTL_PRIVATE_DEF_SEL(kernelEndTime, + "kernelEndTime"); +_MTL_PRIVATE_DEF_SEL(kernelStartTime, + "kernelStartTime"); +_MTL_PRIVATE_DEF_SEL(label, + "label"); +_MTL_PRIVATE_DEF_SEL(languageVersion, + "languageVersion"); +_MTL_PRIVATE_DEF_SEL(layerAtIndex_, + "layerAtIndex:"); 
+_MTL_PRIVATE_DEF_SEL(layerCount, + "layerCount"); +_MTL_PRIVATE_DEF_SEL(layers, + "layers"); +_MTL_PRIVATE_DEF_SEL(layouts, + "layouts"); +_MTL_PRIVATE_DEF_SEL(length, + "length"); +_MTL_PRIVATE_DEF_SEL(level, + "level"); +_MTL_PRIVATE_DEF_SEL(levelRange, + "levelRange"); +_MTL_PRIVATE_DEF_SEL(libraries, + "libraries"); +_MTL_PRIVATE_DEF_SEL(library, + "library"); +_MTL_PRIVATE_DEF_SEL(libraryType, + "libraryType"); +_MTL_PRIVATE_DEF_SEL(line, + "line"); +_MTL_PRIVATE_DEF_SEL(linkedFunctions, + "linkedFunctions"); +_MTL_PRIVATE_DEF_SEL(loadAction, + "loadAction"); +_MTL_PRIVATE_DEF_SEL(loadBuffer_offset_size_sourceHandle_sourceHandleOffset_, + "loadBuffer:offset:size:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(loadBytes_size_sourceHandle_sourceHandleOffset_, + "loadBytes:size:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(loadTexture_slice_level_size_sourceBytesPerRow_sourceBytesPerImage_destinationOrigin_sourceHandle_sourceHandleOffset_, + "loadTexture:slice:level:size:sourceBytesPerRow:sourceBytesPerImage:destinationOrigin:sourceHandle:sourceHandleOffset:"); +_MTL_PRIVATE_DEF_SEL(location, + "location"); +_MTL_PRIVATE_DEF_SEL(locationNumber, + "locationNumber"); +_MTL_PRIVATE_DEF_SEL(lodAverage, + "lodAverage"); +_MTL_PRIVATE_DEF_SEL(lodBias, + "lodBias"); +_MTL_PRIVATE_DEF_SEL(lodMaxClamp, + "lodMaxClamp"); +_MTL_PRIVATE_DEF_SEL(lodMinClamp, + "lodMinClamp"); +_MTL_PRIVATE_DEF_SEL(logState, + "logState"); +_MTL_PRIVATE_DEF_SEL(logs, + "logs"); +_MTL_PRIVATE_DEF_SEL(lookupArchives, + "lookupArchives"); +_MTL_PRIVATE_DEF_SEL(machineLearningCommandEncoder, + "machineLearningCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(machineLearningFunctionDescriptor, + "machineLearningFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(magFilter, + "magFilter"); +_MTL_PRIVATE_DEF_SEL(makeAliasable, + "makeAliasable"); +_MTL_PRIVATE_DEF_SEL(mapPhysicalToScreenCoordinates_forLayer_, + "mapPhysicalToScreenCoordinates:forLayer:"); 
+_MTL_PRIVATE_DEF_SEL(mapScreenToPhysicalCoordinates_forLayer_, + "mapScreenToPhysicalCoordinates:forLayer:"); +_MTL_PRIVATE_DEF_SEL(mathFloatingPointFunctions, + "mathFloatingPointFunctions"); +_MTL_PRIVATE_DEF_SEL(mathMode, + "mathMode"); +_MTL_PRIVATE_DEF_SEL(maxAnisotropy, + "maxAnisotropy"); +_MTL_PRIVATE_DEF_SEL(maxArgumentBufferSamplerCount, + "maxArgumentBufferSamplerCount"); +_MTL_PRIVATE_DEF_SEL(maxAvailableSizeWithAlignment_, + "maxAvailableSizeWithAlignment:"); +_MTL_PRIVATE_DEF_SEL(maxBufferBindCount, + "maxBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxBufferLength, + "maxBufferLength"); +_MTL_PRIVATE_DEF_SEL(maxCallStackDepth, + "maxCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maxCommandBufferCount, + "maxCommandBufferCount"); +_MTL_PRIVATE_DEF_SEL(maxCommandsInFlight, + "maxCommandsInFlight"); +_MTL_PRIVATE_DEF_SEL(maxCompatiblePlacementSparsePageSize, + "maxCompatiblePlacementSparsePageSize"); +_MTL_PRIVATE_DEF_SEL(maxFragmentBufferBindCount, + "maxFragmentBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxFragmentCallStackDepth, + "maxFragmentCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maxInstanceCount, + "maxInstanceCount"); +_MTL_PRIVATE_DEF_SEL(maxKernelBufferBindCount, + "maxKernelBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxKernelThreadgroupMemoryBindCount, + "maxKernelThreadgroupMemoryBindCount"); +_MTL_PRIVATE_DEF_SEL(maxMeshBufferBindCount, + "maxMeshBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxMotionTransformCount, + "maxMotionTransformCount"); +_MTL_PRIVATE_DEF_SEL(maxObjectBufferBindCount, + "maxObjectBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxObjectThreadgroupMemoryBindCount, + "maxObjectThreadgroupMemoryBindCount"); +_MTL_PRIVATE_DEF_SEL(maxSampleCount, + "maxSampleCount"); +_MTL_PRIVATE_DEF_SEL(maxSamplerStateBindCount, + "maxSamplerStateBindCount"); +_MTL_PRIVATE_DEF_SEL(maxTessellationFactor, + "maxTessellationFactor"); +_MTL_PRIVATE_DEF_SEL(maxTextureBindCount, + "maxTextureBindCount"); +_MTL_PRIVATE_DEF_SEL(maxThreadgroupMemoryLength, + 
"maxThreadgroupMemoryLength"); +_MTL_PRIVATE_DEF_SEL(maxThreadsPerThreadgroup, + "maxThreadsPerThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadgroupsPerMeshGrid, + "maxTotalThreadgroupsPerMeshGrid"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerMeshThreadgroup, + "maxTotalThreadsPerMeshThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerObjectThreadgroup, + "maxTotalThreadsPerObjectThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTotalThreadsPerThreadgroup, + "maxTotalThreadsPerThreadgroup"); +_MTL_PRIVATE_DEF_SEL(maxTransferRate, + "maxTransferRate"); +_MTL_PRIVATE_DEF_SEL(maxVertexAmplificationCount, + "maxVertexAmplificationCount"); +_MTL_PRIVATE_DEF_SEL(maxVertexBufferBindCount, + "maxVertexBufferBindCount"); +_MTL_PRIVATE_DEF_SEL(maxVertexCallStackDepth, + "maxVertexCallStackDepth"); +_MTL_PRIVATE_DEF_SEL(maximumConcurrentCompilationTaskCount, + "maximumConcurrentCompilationTaskCount"); +_MTL_PRIVATE_DEF_SEL(memberByName_, + "memberByName:"); +_MTL_PRIVATE_DEF_SEL(members, + "members"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithResources_count_, + "memoryBarrierWithResources:count:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithResources_count_afterStages_beforeStages_, + "memoryBarrierWithResources:count:afterStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithScope_, + "memoryBarrierWithScope:"); +_MTL_PRIVATE_DEF_SEL(memoryBarrierWithScope_afterStages_beforeStages_, + "memoryBarrierWithScope:afterStages:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(meshAdditionalBinaryFunctions, + "meshAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(meshBindings, + "meshBindings"); +_MTL_PRIVATE_DEF_SEL(meshBuffers, + "meshBuffers"); +_MTL_PRIVATE_DEF_SEL(meshFunction, + "meshFunction"); +_MTL_PRIVATE_DEF_SEL(meshFunctionDescriptor, + "meshFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshLinkedFunctions, + "meshLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(meshLinkingDescriptor, + "meshLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshStaticLinkingDescriptor, + 
"meshStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(meshThreadExecutionWidth, + "meshThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth, + "meshThreadgroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(minFilter, + "minFilter"); +_MTL_PRIVATE_DEF_SEL(minimumLinearTextureAlignmentForPixelFormat_, + "minimumLinearTextureAlignmentForPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(minimumTextureBufferAlignmentForPixelFormat_, + "minimumTextureBufferAlignmentForPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(mipFilter, + "mipFilter"); +_MTL_PRIVATE_DEF_SEL(mipmapLevelCount, + "mipmapLevelCount"); +_MTL_PRIVATE_DEF_SEL(motionEndBorderMode, + "motionEndBorderMode"); +_MTL_PRIVATE_DEF_SEL(motionEndTime, + "motionEndTime"); +_MTL_PRIVATE_DEF_SEL(motionKeyframeCount, + "motionKeyframeCount"); +_MTL_PRIVATE_DEF_SEL(motionStartBorderMode, + "motionStartBorderMode"); +_MTL_PRIVATE_DEF_SEL(motionStartTime, + "motionStartTime"); +_MTL_PRIVATE_DEF_SEL(motionTransformBuffer, + "motionTransformBuffer"); +_MTL_PRIVATE_DEF_SEL(motionTransformBufferOffset, + "motionTransformBufferOffset"); +_MTL_PRIVATE_DEF_SEL(motionTransformCount, + "motionTransformCount"); +_MTL_PRIVATE_DEF_SEL(motionTransformCountBuffer, + "motionTransformCountBuffer"); +_MTL_PRIVATE_DEF_SEL(motionTransformCountBufferOffset, + "motionTransformCountBufferOffset"); +_MTL_PRIVATE_DEF_SEL(motionTransformStride, + "motionTransformStride"); +_MTL_PRIVATE_DEF_SEL(motionTransformType, + "motionTransformType"); +_MTL_PRIVATE_DEF_SEL(moveTextureMappingsFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_, + "moveTextureMappingsFromTexture:sourceSlice:sourceLevel:sourceOrigin:sourceSize:toTexture:destinationSlice:destinationLevel:destinationOrigin:"); +_MTL_PRIVATE_DEF_SEL(mutability, + "mutability"); +_MTL_PRIVATE_DEF_SEL(name, + "name"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithDescriptor_, + 
"newAccelerationStructureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithDescriptor_offset_, + "newAccelerationStructureWithDescriptor:offset:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithSize_, + "newAccelerationStructureWithSize:"); +_MTL_PRIVATE_DEF_SEL(newAccelerationStructureWithSize_offset_, + "newAccelerationStructureWithSize:offset:"); +_MTL_PRIVATE_DEF_SEL(newArchiveWithURL_error_, + "newArchiveWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderForBufferAtIndex_, + "newArgumentEncoderForBufferAtIndex:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithArguments_, + "newArgumentEncoderWithArguments:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferBinding_, + "newArgumentEncoderWithBufferBinding:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferIndex_, + "newArgumentEncoderWithBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(newArgumentEncoderWithBufferIndex_reflection_, + "newArgumentEncoderWithBufferIndex:reflection:"); +_MTL_PRIVATE_DEF_SEL(newArgumentTableWithDescriptor_error_, + "newArgumentTableWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryArchiveWithDescriptor_error_, + "newBinaryArchiveWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_completionHandler_, + "newBinaryFunctionWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_compilerTaskOptions_error_, + "newBinaryFunctionWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newBinaryFunctionWithDescriptor_error_, + "newBinaryFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithBytes_length_options_, + "newBufferWithBytes:length:options:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithBytesNoCopy_length_options_deallocator_, + "newBufferWithBytesNoCopy:length:options:deallocator:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_, + "newBufferWithLength:options:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_offset_, + 
"newBufferWithLength:options:offset:"); +_MTL_PRIVATE_DEF_SEL(newBufferWithLength_options_placementSparsePageSize_, + "newBufferWithLength:options:placementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithCommandQueue_, + "newCaptureScopeWithCommandQueue:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithDevice_, + "newCaptureScopeWithDevice:"); +_MTL_PRIVATE_DEF_SEL(newCaptureScopeWithMTL4CommandQueue_, + "newCaptureScopeWithMTL4CommandQueue:"); +_MTL_PRIVATE_DEF_SEL(newCommandAllocator, + "newCommandAllocator"); +_MTL_PRIVATE_DEF_SEL(newCommandAllocatorWithDescriptor_error_, + "newCommandAllocatorWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newCommandBuffer, + "newCommandBuffer"); +_MTL_PRIVATE_DEF_SEL(newCommandQueue, + "newCommandQueue"); +_MTL_PRIVATE_DEF_SEL(newCommandQueueWithDescriptor_, + "newCommandQueueWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newCommandQueueWithMaxCommandBufferCount_, + "newCommandQueueWithMaxCommandBufferCount:"); +_MTL_PRIVATE_DEF_SEL(newCompilerWithDescriptor_error_, + "newCompilerWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_, + "newComputePipelineStateWithAdditionalBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithBinaryFunctions_error_, + "newComputePipelineStateWithBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_completionHandler_, + "newComputePipelineStateWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_compilerTaskOptions_error_, + "newComputePipelineStateWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:completionHandler:"); 
+_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_dynamicLinkingDescriptor_error_, + "newComputePipelineStateWithDescriptor:dynamicLinkingDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_error_, + "newComputePipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_options_completionHandler_, + "newComputePipelineStateWithDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithDescriptor_options_reflection_error_, + "newComputePipelineStateWithDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_completionHandler_, + "newComputePipelineStateWithFunction:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_error_, + "newComputePipelineStateWithFunction:error:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_options_completionHandler_, + "newComputePipelineStateWithFunction:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newComputePipelineStateWithFunction_options_reflection_error_, + "newComputePipelineStateWithFunction:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newCounterHeapWithDescriptor_error_, + "newCounterHeapWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newCounterSampleBufferWithDescriptor_error_, + "newCounterSampleBufferWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newDefaultLibrary, + "newDefaultLibrary"); +_MTL_PRIVATE_DEF_SEL(newDefaultLibraryWithBundle_error_, + "newDefaultLibraryWithBundle:error:"); +_MTL_PRIVATE_DEF_SEL(newDepthStencilStateWithDescriptor_, + "newDepthStencilStateWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibrary_completionHandler_, + "newDynamicLibrary:completionHandler:"); 
+_MTL_PRIVATE_DEF_SEL(newDynamicLibrary_error_, + "newDynamicLibrary:error:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibraryWithURL_completionHandler_, + "newDynamicLibraryWithURL:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newDynamicLibraryWithURL_error_, + "newDynamicLibraryWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newEvent, + "newEvent"); +_MTL_PRIVATE_DEF_SEL(newFence, + "newFence"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithDescriptor_completionHandler_, + "newFunctionWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithDescriptor_error_, + "newFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_, + "newFunctionWithName:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_constantValues_completionHandler_, + "newFunctionWithName:constantValues:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newFunctionWithName_constantValues_error_, + "newFunctionWithName:constantValues:error:"); +_MTL_PRIVATE_DEF_SEL(newHeapWithDescriptor_, + "newHeapWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newIOCommandQueueWithDescriptor_error_, + "newIOCommandQueueWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newIOFileHandleWithURL_compressionMethod_error_, + "newIOFileHandleWithURL:compressionMethod:error:"); +_MTL_PRIVATE_DEF_SEL(newIOFileHandleWithURL_error_, + "newIOFileHandleWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newIOHandleWithURL_compressionMethod_error_, + "newIOHandleWithURL:compressionMethod:error:"); +_MTL_PRIVATE_DEF_SEL(newIOHandleWithURL_error_, + "newIOHandleWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newIndirectCommandBufferWithDescriptor_maxCommandCount_options_, + "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionTableWithDescriptor_, + "newIntersectionFunctionTableWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionTableWithDescriptor_stage_, + "newIntersectionFunctionTableWithDescriptor:stage:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionWithDescriptor_completionHandler_, + 
"newIntersectionFunctionWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newIntersectionFunctionWithDescriptor_error_, + "newIntersectionFunctionWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithData_error_, + "newLibraryWithData:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithDescriptor_completionHandler_, + "newLibraryWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithDescriptor_error_, + "newLibraryWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithFile_error_, + "newLibraryWithFile:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithSource_options_completionHandler_, + "newLibraryWithSource:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithSource_options_error_, + "newLibraryWithSource:options:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithStitchedDescriptor_completionHandler_, + "newLibraryWithStitchedDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithStitchedDescriptor_error_, + "newLibraryWithStitchedDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newLibraryWithURL_error_, + "newLibraryWithURL:error:"); +_MTL_PRIVATE_DEF_SEL(newLogStateWithDescriptor_error_, + "newLogStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newMTL4CommandQueue, + "newMTL4CommandQueue"); +_MTL_PRIVATE_DEF_SEL(newMTL4CommandQueueWithDescriptor_error_, + "newMTL4CommandQueueWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newMachineLearningPipelineStateWithDescriptor_completionHandler_, + "newMachineLearningPipelineStateWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newMachineLearningPipelineStateWithDescriptor_error_, + "newMachineLearningPipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newPipelineDataSetSerializerWithDescriptor_, + "newPipelineDataSetSerializerWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newRasterizationRateMapWithDescriptor_, + "newRasterizationRateMapWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newRemoteBufferViewForDevice_, + "newRemoteBufferViewForDevice:"); 
+_MTL_PRIVATE_DEF_SEL(newRemoteTextureViewForDevice_, + "newRemoteTextureViewForDevice:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineDescriptorForSpecialization, + "newRenderPipelineDescriptorForSpecialization"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_completionHandler_, + "newRenderPipelineStateBySpecializationWithDescriptor:pipeline:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateBySpecializationWithDescriptor_pipeline_error_, + "newRenderPipelineStateBySpecializationWithDescriptor:pipeline:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithAdditionalBinaryFunctions_error_, + "newRenderPipelineStateWithAdditionalBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithBinaryFunctions_error_, + "newRenderPipelineStateWithBinaryFunctions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_completionHandler_, + "newRenderPipelineStateWithDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_compilerTaskOptions_error_, + "newRenderPipelineStateWithDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_completionHandler_, + "newRenderPipelineStateWithDescriptor:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_completionHandler_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_compilerTaskOptions_error_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:compilerTaskOptions:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_dynamicLinkingDescriptor_error_, + "newRenderPipelineStateWithDescriptor:dynamicLinkingDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_error_, + 
"newRenderPipelineStateWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_options_completionHandler_, + "newRenderPipelineStateWithDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithDescriptor_options_reflection_error_, + "newRenderPipelineStateWithDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithMeshDescriptor_options_completionHandler_, + "newRenderPipelineStateWithMeshDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithMeshDescriptor_options_reflection_error_, + "newRenderPipelineStateWithMeshDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_completionHandler_, + "newRenderPipelineStateWithTileDescriptor:options:completionHandler:"); +_MTL_PRIVATE_DEF_SEL(newRenderPipelineStateWithTileDescriptor_options_reflection_error_, + "newRenderPipelineStateWithTileDescriptor:options:reflection:error:"); +_MTL_PRIVATE_DEF_SEL(newResidencySetWithDescriptor_error_, + "newResidencySetWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newSamplerStateWithDescriptor_, + "newSamplerStateWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newScratchBufferWithMinimumSize_, + "newScratchBufferWithMinimumSize:"); +_MTL_PRIVATE_DEF_SEL(newSharedEvent, + "newSharedEvent"); +_MTL_PRIVATE_DEF_SEL(newSharedEventHandle, + "newSharedEventHandle"); +_MTL_PRIVATE_DEF_SEL(newSharedEventWithHandle_, + "newSharedEventWithHandle:"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureHandle, + "newSharedTextureHandle"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureWithDescriptor_, + "newSharedTextureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newSharedTextureWithHandle_, + "newSharedTextureWithHandle:"); +_MTL_PRIVATE_DEF_SEL(newTensorWithDescriptor_error_, + "newTensorWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newTensorWithDescriptor_offset_error_, + "newTensorWithDescriptor:offset:error:"); 
+_MTL_PRIVATE_DEF_SEL(newTextureViewPoolWithDescriptor_error_, + "newTextureViewPoolWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithDescriptor_, + "newTextureViewWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_, + "newTextureViewWithPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_, + "newTextureViewWithPixelFormat:textureType:levels:slices:"); +_MTL_PRIVATE_DEF_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_swizzle_, + "newTextureViewWithPixelFormat:textureType:levels:slices:swizzle:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_, + "newTextureWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_iosurface_plane_, + "newTextureWithDescriptor:iosurface:plane:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_offset_, + "newTextureWithDescriptor:offset:"); +_MTL_PRIVATE_DEF_SEL(newTextureWithDescriptor_offset_bytesPerRow_, + "newTextureWithDescriptor:offset:bytesPerRow:"); +_MTL_PRIVATE_DEF_SEL(newVisibleFunctionTableWithDescriptor_, + "newVisibleFunctionTableWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(newVisibleFunctionTableWithDescriptor_stage_, + "newVisibleFunctionTableWithDescriptor:stage:"); +_MTL_PRIVATE_DEF_SEL(nodes, + "nodes"); +_MTL_PRIVATE_DEF_SEL(normalizedCoordinates, + "normalizedCoordinates"); +_MTL_PRIVATE_DEF_SEL(notifyListener_atValue_block_, + "notifyListener:atValue:block:"); +_MTL_PRIVATE_DEF_SEL(objectAdditionalBinaryFunctions, + "objectAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(objectAtIndexedSubscript_, + "objectAtIndexedSubscript:"); +_MTL_PRIVATE_DEF_SEL(objectBindings, + "objectBindings"); +_MTL_PRIVATE_DEF_SEL(objectBuffers, + "objectBuffers"); +_MTL_PRIVATE_DEF_SEL(objectFunction, + "objectFunction"); +_MTL_PRIVATE_DEF_SEL(objectFunctionDescriptor, + "objectFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectLinkedFunctions, + "objectLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(objectLinkingDescriptor, + 
"objectLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectPayloadAlignment, + "objectPayloadAlignment"); +_MTL_PRIVATE_DEF_SEL(objectPayloadDataSize, + "objectPayloadDataSize"); +_MTL_PRIVATE_DEF_SEL(objectStaticLinkingDescriptor, + "objectStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(objectThreadExecutionWidth, + "objectThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth, + "objectThreadgroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(offset, + "offset"); +_MTL_PRIVATE_DEF_SEL(opaque, + "opaque"); +_MTL_PRIVATE_DEF_SEL(optimizationLevel, + "optimizationLevel"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForCPUAccess_, + "optimizeContentsForCPUAccess:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForCPUAccess_slice_level_, + "optimizeContentsForCPUAccess:slice:level:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForGPUAccess_, + "optimizeContentsForGPUAccess:"); +_MTL_PRIVATE_DEF_SEL(optimizeContentsForGPUAccess_slice_level_, + "optimizeContentsForGPUAccess:slice:level:"); +_MTL_PRIVATE_DEF_SEL(optimizeIndirectCommandBuffer_withRange_, + "optimizeIndirectCommandBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(options, + "options"); +_MTL_PRIVATE_DEF_SEL(outputNode, + "outputNode"); +_MTL_PRIVATE_DEF_SEL(outputURL, + "outputURL"); +_MTL_PRIVATE_DEF_SEL(parallelRenderCommandEncoderWithDescriptor_, + "parallelRenderCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(parameterBufferSizeAndAlign, + "parameterBufferSizeAndAlign"); +_MTL_PRIVATE_DEF_SEL(parentRelativeLevel, + "parentRelativeLevel"); +_MTL_PRIVATE_DEF_SEL(parentRelativeSlice, + "parentRelativeSlice"); +_MTL_PRIVATE_DEF_SEL(parentTexture, + "parentTexture"); +_MTL_PRIVATE_DEF_SEL(patchControlPointCount, + "patchControlPointCount"); +_MTL_PRIVATE_DEF_SEL(patchType, + "patchType"); +_MTL_PRIVATE_DEF_SEL(payloadMemoryLength, + "payloadMemoryLength"); +_MTL_PRIVATE_DEF_SEL(peerCount, + "peerCount"); +_MTL_PRIVATE_DEF_SEL(peerGroupID, + "peerGroupID"); 
+_MTL_PRIVATE_DEF_SEL(peerIndex, + "peerIndex"); +_MTL_PRIVATE_DEF_SEL(physicalGranularity, + "physicalGranularity"); +_MTL_PRIVATE_DEF_SEL(physicalSizeForLayer_, + "physicalSizeForLayer:"); +_MTL_PRIVATE_DEF_SEL(pipelineDataSetSerializer, + "pipelineDataSetSerializer"); +_MTL_PRIVATE_DEF_SEL(pixelFormat, + "pixelFormat"); +_MTL_PRIVATE_DEF_SEL(placementSparsePageSize, + "placementSparsePageSize"); +_MTL_PRIVATE_DEF_SEL(pointerType, + "pointerType"); +_MTL_PRIVATE_DEF_SEL(popDebugGroup, + "popDebugGroup"); +_MTL_PRIVATE_DEF_SEL(preloadedLibraries, + "preloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(preprocessorMacros, + "preprocessorMacros"); +_MTL_PRIVATE_DEF_SEL(present, + "present"); +_MTL_PRIVATE_DEF_SEL(presentAfterMinimumDuration_, + "presentAfterMinimumDuration:"); +_MTL_PRIVATE_DEF_SEL(presentAtTime_, + "presentAtTime:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_, + "presentDrawable:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_afterMinimumDuration_, + "presentDrawable:afterMinimumDuration:"); +_MTL_PRIVATE_DEF_SEL(presentDrawable_atTime_, + "presentDrawable:atTime:"); +_MTL_PRIVATE_DEF_SEL(presentedTime, + "presentedTime"); +_MTL_PRIVATE_DEF_SEL(preserveInvariance, + "preserveInvariance"); +_MTL_PRIVATE_DEF_SEL(primitiveDataBuffer, + "primitiveDataBuffer"); +_MTL_PRIVATE_DEF_SEL(primitiveDataBufferOffset, + "primitiveDataBufferOffset"); +_MTL_PRIVATE_DEF_SEL(primitiveDataElementSize, + "primitiveDataElementSize"); +_MTL_PRIVATE_DEF_SEL(primitiveDataStride, + "primitiveDataStride"); +_MTL_PRIVATE_DEF_SEL(priority, + "priority"); +_MTL_PRIVATE_DEF_SEL(privateFunctionDescriptors, + "privateFunctionDescriptors"); +_MTL_PRIVATE_DEF_SEL(privateFunctions, + "privateFunctions"); +_MTL_PRIVATE_DEF_SEL(pushDebugGroup_, + "pushDebugGroup:"); +_MTL_PRIVATE_DEF_SEL(queryTimestampFrequency, + "queryTimestampFrequency"); +_MTL_PRIVATE_DEF_SEL(rAddressMode, + "rAddressMode"); +_MTL_PRIVATE_DEF_SEL(radiusBuffer, + "radiusBuffer"); +_MTL_PRIVATE_DEF_SEL(radiusBufferOffset, + 
"radiusBufferOffset"); +_MTL_PRIVATE_DEF_SEL(radiusBuffers, + "radiusBuffers"); +_MTL_PRIVATE_DEF_SEL(radiusFormat, + "radiusFormat"); +_MTL_PRIVATE_DEF_SEL(radiusStride, + "radiusStride"); +_MTL_PRIVATE_DEF_SEL(rank, + "rank"); +_MTL_PRIVATE_DEF_SEL(rasterSampleCount, + "rasterSampleCount"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMap, + "rasterizationRateMap"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_, + "rasterizationRateMapDescriptorWithScreenSize:"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_layer_, + "rasterizationRateMapDescriptorWithScreenSize:layer:"); +_MTL_PRIVATE_DEF_SEL(rasterizationRateMapDescriptorWithScreenSize_layerCount_layers_, + "rasterizationRateMapDescriptorWithScreenSize:layerCount:layers:"); +_MTL_PRIVATE_DEF_SEL(readMask, + "readMask"); +_MTL_PRIVATE_DEF_SEL(readWriteTextureSupport, + "readWriteTextureSupport"); +_MTL_PRIVATE_DEF_SEL(recommendedMaxWorkingSetSize, + "recommendedMaxWorkingSetSize"); +_MTL_PRIVATE_DEF_SEL(reductionMode, + "reductionMode"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_options_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:options:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(refitAccelerationStructure_descriptor_destination_scratchBuffer_scratchBufferOffset_options_, + "refitAccelerationStructure:descriptor:destination:scratchBuffer:scratchBufferOffset:options:"); +_MTL_PRIVATE_DEF_SEL(reflection, + "reflection"); +_MTL_PRIVATE_DEF_SEL(reflectionForFunctionWithName_, + "reflectionForFunctionWithName:"); +_MTL_PRIVATE_DEF_SEL(registryID, + "registryID"); 
+_MTL_PRIVATE_DEF_SEL(remoteStorageBuffer, + "remoteStorageBuffer"); +_MTL_PRIVATE_DEF_SEL(remoteStorageTexture, + "remoteStorageTexture"); +_MTL_PRIVATE_DEF_SEL(removeAllAllocations, + "removeAllAllocations"); +_MTL_PRIVATE_DEF_SEL(removeAllDebugMarkers, + "removeAllDebugMarkers"); +_MTL_PRIVATE_DEF_SEL(removeAllocation_, + "removeAllocation:"); +_MTL_PRIVATE_DEF_SEL(removeAllocations_count_, + "removeAllocations:count:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySet_, + "removeResidencySet:"); +_MTL_PRIVATE_DEF_SEL(removeResidencySets_count_, + "removeResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoder, + "renderCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoderWithDescriptor_, + "renderCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(renderCommandEncoderWithDescriptor_options_, + "renderCommandEncoderWithDescriptor:options:"); +_MTL_PRIVATE_DEF_SEL(renderPassDescriptor, + "renderPassDescriptor"); +_MTL_PRIVATE_DEF_SEL(renderTargetArrayLength, + "renderTargetArrayLength"); +_MTL_PRIVATE_DEF_SEL(renderTargetHeight, + "renderTargetHeight"); +_MTL_PRIVATE_DEF_SEL(renderTargetWidth, + "renderTargetWidth"); +_MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_slice_withBytes_bytesPerRow_bytesPerImage_, + "replaceRegion:mipmapLevel:slice:withBytes:bytesPerRow:bytesPerImage:"); +_MTL_PRIVATE_DEF_SEL(replaceRegion_mipmapLevel_withBytes_bytesPerRow_, + "replaceRegion:mipmapLevel:withBytes:bytesPerRow:"); +_MTL_PRIVATE_DEF_SEL(replaceSliceOrigin_sliceDimensions_withBytes_strides_, + "replaceSliceOrigin:sliceDimensions:withBytes:strides:"); +_MTL_PRIVATE_DEF_SEL(requestResidency, + "requestResidency"); +_MTL_PRIVATE_DEF_SEL(required, + "required"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerMeshThreadgroup, + "requiredThreadsPerMeshThreadgroup"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerObjectThreadgroup, + "requiredThreadsPerObjectThreadgroup"); +_MTL_PRIVATE_DEF_SEL(requiredThreadsPerThreadgroup, + "requiredThreadsPerThreadgroup"); 
+_MTL_PRIVATE_DEF_SEL(requiredThreadsPerTileThreadgroup, + "requiredThreadsPerTileThreadgroup"); +_MTL_PRIVATE_DEF_SEL(reset, + "reset"); +_MTL_PRIVATE_DEF_SEL(resetCommandsInBuffer_withRange_, + "resetCommandsInBuffer:withRange:"); +_MTL_PRIVATE_DEF_SEL(resetTextureAccessCounters_region_mipLevel_slice_, + "resetTextureAccessCounters:region:mipLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(resetWithRange_, + "resetWithRange:"); +_MTL_PRIVATE_DEF_SEL(resolveCounterHeap_withRange_intoBuffer_waitFence_updateFence_, + "resolveCounterHeap:withRange:intoBuffer:waitFence:updateFence:"); +_MTL_PRIVATE_DEF_SEL(resolveCounterRange_, + "resolveCounterRange:"); +_MTL_PRIVATE_DEF_SEL(resolveCounters_inRange_destinationBuffer_destinationOffset_, + "resolveCounters:inRange:destinationBuffer:destinationOffset:"); +_MTL_PRIVATE_DEF_SEL(resolveDepthPlane, + "resolveDepthPlane"); +_MTL_PRIVATE_DEF_SEL(resolveLevel, + "resolveLevel"); +_MTL_PRIVATE_DEF_SEL(resolveSlice, + "resolveSlice"); +_MTL_PRIVATE_DEF_SEL(resolveTexture, + "resolveTexture"); +_MTL_PRIVATE_DEF_SEL(resourceOptions, + "resourceOptions"); +_MTL_PRIVATE_DEF_SEL(resourceStateCommandEncoder, + "resourceStateCommandEncoder"); +_MTL_PRIVATE_DEF_SEL(resourceStateCommandEncoderWithDescriptor_, + "resourceStateCommandEncoderWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(resourceStatePassDescriptor, + "resourceStatePassDescriptor"); +_MTL_PRIVATE_DEF_SEL(resourceViewCount, + "resourceViewCount"); +_MTL_PRIVATE_DEF_SEL(retainedReferences, + "retainedReferences"); +_MTL_PRIVATE_DEF_SEL(rgbBlendOperation, + "rgbBlendOperation"); +_MTL_PRIVATE_DEF_SEL(rootResource, + "rootResource"); +_MTL_PRIVATE_DEF_SEL(sAddressMode, + "sAddressMode"); +_MTL_PRIVATE_DEF_SEL(sampleBuffer, + "sampleBuffer"); +_MTL_PRIVATE_DEF_SEL(sampleBufferAttachments, + "sampleBufferAttachments"); +_MTL_PRIVATE_DEF_SEL(sampleCount, + "sampleCount"); +_MTL_PRIVATE_DEF_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_, + 
"sampleCountersInBuffer:atSampleIndex:withBarrier:"); +_MTL_PRIVATE_DEF_SEL(sampleTimestamps_gpuTimestamp_, + "sampleTimestamps:gpuTimestamp:"); +_MTL_PRIVATE_DEF_SEL(scratchBufferAllocator, + "scratchBufferAllocator"); +_MTL_PRIVATE_DEF_SEL(screenSize, + "screenSize"); +_MTL_PRIVATE_DEF_SEL(segmentControlPointCount, + "segmentControlPointCount"); +_MTL_PRIVATE_DEF_SEL(segmentCount, + "segmentCount"); +_MTL_PRIVATE_DEF_SEL(serializeAsArchiveAndFlushToURL_error_, + "serializeAsArchiveAndFlushToURL:error:"); +_MTL_PRIVATE_DEF_SEL(serializeAsPipelinesScriptWithError_, + "serializeAsPipelinesScriptWithError:"); +_MTL_PRIVATE_DEF_SEL(serializeToURL_error_, + "serializeToURL:error:"); +_MTL_PRIVATE_DEF_SEL(setAccelerationStructure_atBufferIndex_, + "setAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setAccelerationStructure_atIndex_, + "setAccelerationStructure:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAccess_, + "setAccess:"); +_MTL_PRIVATE_DEF_SEL(setAddress_atIndex_, + "setAddress:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAddress_attributeStride_atIndex_, + "setAddress:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setAllowDuplicateIntersectionFunctionInvocation_, + "setAllowDuplicateIntersectionFunctionInvocation:"); +_MTL_PRIVATE_DEF_SEL(setAllowGPUOptimizedContents_, + "setAllowGPUOptimizedContents:"); +_MTL_PRIVATE_DEF_SEL(setAllowReferencingUndefinedSymbols_, + "setAllowReferencingUndefinedSymbols:"); +_MTL_PRIVATE_DEF_SEL(setAlphaBlendOperation_, + "setAlphaBlendOperation:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToCoverageEnabled_, + "setAlphaToCoverageEnabled:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToCoverageState_, + "setAlphaToCoverageState:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToOneEnabled_, + "setAlphaToOneEnabled:"); +_MTL_PRIVATE_DEF_SEL(setAlphaToOneState_, + "setAlphaToOneState:"); +_MTL_PRIVATE_DEF_SEL(setArgumentBuffer_offset_, + "setArgumentBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(setArgumentBuffer_startOffset_arrayElement_, + 
"setArgumentBuffer:startOffset:arrayElement:"); +_MTL_PRIVATE_DEF_SEL(setArgumentIndex_, + "setArgumentIndex:"); +_MTL_PRIVATE_DEF_SEL(setArgumentTable_, + "setArgumentTable:"); +_MTL_PRIVATE_DEF_SEL(setArgumentTable_atStages_, + "setArgumentTable:atStages:"); +_MTL_PRIVATE_DEF_SEL(setArguments_, + "setArguments:"); +_MTL_PRIVATE_DEF_SEL(setArrayLength_, + "setArrayLength:"); +_MTL_PRIVATE_DEF_SEL(setAttributes_, + "setAttributes:"); +_MTL_PRIVATE_DEF_SEL(setBackFaceStencil_, + "setBackFaceStencil:"); +_MTL_PRIVATE_DEF_SEL(setBarrier, + "setBarrier"); +_MTL_PRIVATE_DEF_SEL(setBinaryArchives_, + "setBinaryArchives:"); +_MTL_PRIVATE_DEF_SEL(setBinaryFunctions_, + "setBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setBinaryLinkedFunctions_, + "setBinaryLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setBlendColorRed_green_blue_alpha_, + "setBlendColorRed:green:blue:alpha:"); +_MTL_PRIVATE_DEF_SEL(setBlendingEnabled_, + "setBlendingEnabled:"); +_MTL_PRIVATE_DEF_SEL(setBlendingState_, + "setBlendingState:"); +_MTL_PRIVATE_DEF_SEL(setBorderColor_, + "setBorderColor:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBuffer_, + "setBoundingBoxBuffer:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBufferOffset_, + "setBoundingBoxBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxBuffers_, + "setBoundingBoxBuffers:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxCount_, + "setBoundingBoxCount:"); +_MTL_PRIVATE_DEF_SEL(setBoundingBoxStride_, + "setBoundingBoxStride:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_, + "setBuffer:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_offset_atIndex_, + "setBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBuffer_offset_attributeStride_atIndex_, + "setBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferIndex_, + "setBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferOffset_atIndex_, + "setBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferOffset_attributeStride_atIndex_, + "setBufferOffset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBufferSize_, + 
"setBufferSize:"); +_MTL_PRIVATE_DEF_SEL(setBuffers_offsets_attributeStrides_withRange_, + "setBuffers:offsets:attributeStrides:withRange:"); +_MTL_PRIVATE_DEF_SEL(setBuffers_offsets_withRange_, + "setBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setBytes_length_atIndex_, + "setBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setBytes_length_attributeStride_atIndex_, + "setBytes:length:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setCaptureObject_, + "setCaptureObject:"); +_MTL_PRIVATE_DEF_SEL(setClearColor_, + "setClearColor:"); +_MTL_PRIVATE_DEF_SEL(setClearDepth_, + "setClearDepth:"); +_MTL_PRIVATE_DEF_SEL(setClearStencil_, + "setClearStencil:"); +_MTL_PRIVATE_DEF_SEL(setColorAttachmentMap_, + "setColorAttachmentMap:"); +_MTL_PRIVATE_DEF_SEL(setColorAttachmentMappingState_, + "setColorAttachmentMappingState:"); +_MTL_PRIVATE_DEF_SEL(setColorStoreAction_atIndex_, + "setColorStoreAction:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setColorStoreActionOptions_atIndex_, + "setColorStoreActionOptions:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setCommandTypes_, + "setCommandTypes:"); +_MTL_PRIVATE_DEF_SEL(setCompareFunction_, + "setCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setCompileSymbolVisibility_, + "setCompileSymbolVisibility:"); +_MTL_PRIVATE_DEF_SEL(setCompressionType_, + "setCompressionType:"); +_MTL_PRIVATE_DEF_SEL(setComputeFunction_, + "setComputeFunction:"); +_MTL_PRIVATE_DEF_SEL(setComputeFunctionDescriptor_, + "setComputeFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineState_, + "setComputePipelineState:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineState_atIndex_, + "setComputePipelineState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setComputePipelineStates_withRange_, + "setComputePipelineStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setConfiguration_, + "setConfiguration:"); +_MTL_PRIVATE_DEF_SEL(setConstantBlockAlignment_, + "setConstantBlockAlignment:"); +_MTL_PRIVATE_DEF_SEL(setConstantValue_type_atIndex_, + "setConstantValue:type:atIndex:"); 
+_MTL_PRIVATE_DEF_SEL(setConstantValue_type_withName_, + "setConstantValue:type:withName:"); +_MTL_PRIVATE_DEF_SEL(setConstantValues_, + "setConstantValues:"); +_MTL_PRIVATE_DEF_SEL(setConstantValues_type_withRange_, + "setConstantValues:type:withRange:"); +_MTL_PRIVATE_DEF_SEL(setControlDependencies_, + "setControlDependencies:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBuffer_, + "setControlPointBuffer:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBufferOffset_, + "setControlPointBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setControlPointBuffers_, + "setControlPointBuffers:"); +_MTL_PRIVATE_DEF_SEL(setControlPointCount_, + "setControlPointCount:"); +_MTL_PRIVATE_DEF_SEL(setControlPointFormat_, + "setControlPointFormat:"); +_MTL_PRIVATE_DEF_SEL(setControlPointStride_, + "setControlPointStride:"); +_MTL_PRIVATE_DEF_SEL(setCount_, + "setCount:"); +_MTL_PRIVATE_DEF_SEL(setCounterSet_, + "setCounterSet:"); +_MTL_PRIVATE_DEF_SEL(setCpuCacheMode_, + "setCpuCacheMode:"); +_MTL_PRIVATE_DEF_SEL(setCullMode_, + "setCullMode:"); +_MTL_PRIVATE_DEF_SEL(setCurveBasis_, + "setCurveBasis:"); +_MTL_PRIVATE_DEF_SEL(setCurveEndCaps_, + "setCurveEndCaps:"); +_MTL_PRIVATE_DEF_SEL(setCurveType_, + "setCurveType:"); +_MTL_PRIVATE_DEF_SEL(setDataType_, + "setDataType:"); +_MTL_PRIVATE_DEF_SEL(setDefaultCaptureScope_, + "setDefaultCaptureScope:"); +_MTL_PRIVATE_DEF_SEL(setDefaultRasterSampleCount_, + "setDefaultRasterSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setDepth_, + "setDepth:"); +_MTL_PRIVATE_DEF_SEL(setDepthAttachment_, + "setDepthAttachment:"); +_MTL_PRIVATE_DEF_SEL(setDepthAttachmentPixelFormat_, + "setDepthAttachmentPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setDepthBias_slopeScale_clamp_, + "setDepthBias:slopeScale:clamp:"); +_MTL_PRIVATE_DEF_SEL(setDepthClipMode_, + "setDepthClipMode:"); +_MTL_PRIVATE_DEF_SEL(setDepthCompareFunction_, + "setDepthCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setDepthFailureOperation_, + "setDepthFailureOperation:"); +_MTL_PRIVATE_DEF_SEL(setDepthPlane_, + 
"setDepthPlane:"); +_MTL_PRIVATE_DEF_SEL(setDepthResolveFilter_, + "setDepthResolveFilter:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilPassOperation_, + "setDepthStencilPassOperation:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilState_, + "setDepthStencilState:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilState_atIndex_, + "setDepthStencilState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setDepthStencilStates_withRange_, + "setDepthStencilStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setDepthStoreAction_, + "setDepthStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setDepthStoreActionOptions_, + "setDepthStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setDepthTestMinBound_maxBound_, + "setDepthTestMinBound:maxBound:"); +_MTL_PRIVATE_DEF_SEL(setDepthWriteEnabled_, + "setDepthWriteEnabled:"); +_MTL_PRIVATE_DEF_SEL(setDestination_, + "setDestination:"); +_MTL_PRIVATE_DEF_SEL(setDestinationAlphaBlendFactor_, + "setDestinationAlphaBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setDestinationRGBBlendFactor_, + "setDestinationRGBBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setDimensions_, + "setDimensions:"); +_MTL_PRIVATE_DEF_SEL(setDispatchType_, + "setDispatchType:"); +_MTL_PRIVATE_DEF_SEL(setEnableLogging_, + "setEnableLogging:"); +_MTL_PRIVATE_DEF_SEL(setEndOfEncoderSampleIndex_, + "setEndOfEncoderSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setEndOfFragmentSampleIndex_, + "setEndOfFragmentSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setEndOfVertexSampleIndex_, + "setEndOfVertexSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setErrorOptions_, + "setErrorOptions:"); +_MTL_PRIVATE_DEF_SEL(setFastMathEnabled_, + "setFastMathEnabled:"); +_MTL_PRIVATE_DEF_SEL(setFeedbackQueue_, + "setFeedbackQueue:"); +_MTL_PRIVATE_DEF_SEL(setFormat_, + "setFormat:"); +_MTL_PRIVATE_DEF_SEL(setFragmentAccelerationStructure_atBufferIndex_, + "setFragmentAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentAdditionalBinaryFunctions_, + "setFragmentAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBuffer_offset_atIndex_, + 
"setFragmentBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBufferOffset_atIndex_, + "setFragmentBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBuffers_offsets_withRange_, + "setFragmentBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentBytes_length_atIndex_, + "setFragmentBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentFunction_, + "setFragmentFunction:"); +_MTL_PRIVATE_DEF_SEL(setFragmentFunctionDescriptor_, + "setFragmentFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFragmentIntersectionFunctionTable_atBufferIndex_, + "setFragmentIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentIntersectionFunctionTables_withBufferRange_, + "setFragmentIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentLinkedFunctions_, + "setFragmentLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFragmentPreloadedLibraries_, + "setFragmentPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerState_atIndex_, + "setFragmentSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setFragmentSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setFragmentSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentSamplerStates_withRange_, + "setFragmentSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentStaticLinkingDescriptor_, + "setFragmentStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFragmentTexture_atIndex_, + "setFragmentTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentTextures_withRange_, + "setFragmentTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setFragmentVisibleFunctionTable_atBufferIndex_, + "setFragmentVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setFragmentVisibleFunctionTables_withBufferRange_, + "setFragmentVisibleFunctionTables:withBufferRange:"); 
+_MTL_PRIVATE_DEF_SEL(setFrontFaceStencil_, + "setFrontFaceStencil:"); +_MTL_PRIVATE_DEF_SEL(setFrontFacingWinding_, + "setFrontFacingWinding:"); +_MTL_PRIVATE_DEF_SEL(setFunction_atIndex_, + "setFunction:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setFunctionCount_, + "setFunctionCount:"); +_MTL_PRIVATE_DEF_SEL(setFunctionDescriptor_, + "setFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setFunctionDescriptors_, + "setFunctionDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setFunctionGraph_, + "setFunctionGraph:"); +_MTL_PRIVATE_DEF_SEL(setFunctionGraphs_, + "setFunctionGraphs:"); +_MTL_PRIVATE_DEF_SEL(setFunctionName_, + "setFunctionName:"); +_MTL_PRIVATE_DEF_SEL(setFunctions_, + "setFunctions:"); +_MTL_PRIVATE_DEF_SEL(setFunctions_withRange_, + "setFunctions:withRange:"); +_MTL_PRIVATE_DEF_SEL(setGeometryDescriptors_, + "setGeometryDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setGroups_, + "setGroups:"); +_MTL_PRIVATE_DEF_SEL(setHazardTrackingMode_, + "setHazardTrackingMode:"); +_MTL_PRIVATE_DEF_SEL(setHeight_, + "setHeight:"); +_MTL_PRIVATE_DEF_SEL(setImageblockSampleLength_, + "setImageblockSampleLength:"); +_MTL_PRIVATE_DEF_SEL(setImageblockWidth_height_, + "setImageblockWidth:height:"); +_MTL_PRIVATE_DEF_SEL(setIndex_, + "setIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndexBuffer_, + "setIndexBuffer:"); +_MTL_PRIVATE_DEF_SEL(setIndexBufferIndex_, + "setIndexBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndexBufferOffset_, + "setIndexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setIndexType_, + "setIndexType:"); +_MTL_PRIVATE_DEF_SEL(setIndirectCommandBuffer_atIndex_, + "setIndirectCommandBuffer:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setIndirectCommandBuffers_withRange_, + "setIndirectCommandBuffers:withRange:"); +_MTL_PRIVATE_DEF_SEL(setInheritBuffers_, + "setInheritBuffers:"); +_MTL_PRIVATE_DEF_SEL(setInheritCullMode_, + "setInheritCullMode:"); +_MTL_PRIVATE_DEF_SEL(setInheritDepthBias_, + "setInheritDepthBias:"); +_MTL_PRIVATE_DEF_SEL(setInheritDepthClipMode_, + "setInheritDepthClipMode:"); 
+_MTL_PRIVATE_DEF_SEL(setInheritDepthStencilState_, + "setInheritDepthStencilState:"); +_MTL_PRIVATE_DEF_SEL(setInheritFrontFacingWinding_, + "setInheritFrontFacingWinding:"); +_MTL_PRIVATE_DEF_SEL(setInheritPipelineState_, + "setInheritPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setInheritTriangleFillMode_, + "setInheritTriangleFillMode:"); +_MTL_PRIVATE_DEF_SEL(setInitialCapacity_, + "setInitialCapacity:"); +_MTL_PRIVATE_DEF_SEL(setInitializeBindings_, + "setInitializeBindings:"); +_MTL_PRIVATE_DEF_SEL(setInputDimensions_atBufferIndex_, + "setInputDimensions:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setInputDimensions_withRange_, + "setInputDimensions:withRange:"); +_MTL_PRIVATE_DEF_SEL(setInputPrimitiveTopology_, + "setInputPrimitiveTopology:"); +_MTL_PRIVATE_DEF_SEL(setInsertLibraries_, + "setInsertLibraries:"); +_MTL_PRIVATE_DEF_SEL(setInstallName_, + "setInstallName:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCount_, + "setInstanceCount:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCountBuffer_, + "setInstanceCountBuffer:"); +_MTL_PRIVATE_DEF_SEL(setInstanceCountBufferOffset_, + "setInstanceCountBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorBuffer_, + "setInstanceDescriptorBuffer:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorBufferOffset_, + "setInstanceDescriptorBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorStride_, + "setInstanceDescriptorStride:"); +_MTL_PRIVATE_DEF_SEL(setInstanceDescriptorType_, + "setInstanceDescriptorType:"); +_MTL_PRIVATE_DEF_SEL(setInstanceTransformationMatrixLayout_, + "setInstanceTransformationMatrixLayout:"); +_MTL_PRIVATE_DEF_SEL(setInstancedAccelerationStructures_, + "setInstancedAccelerationStructures:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTable_atBufferIndex_, + "setIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTable_atIndex_, + "setIntersectionFunctionTable:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTableOffset_, + 
"setIntersectionFunctionTableOffset:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTables_withBufferRange_, + "setIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setIntersectionFunctionTables_withRange_, + "setIntersectionFunctionTables:withRange:"); +_MTL_PRIVATE_DEF_SEL(setKernelBuffer_offset_atIndex_, + "setKernelBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setKernelBuffer_offset_attributeStride_atIndex_, + "setKernelBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setLabel_, + "setLabel:"); +_MTL_PRIVATE_DEF_SEL(setLanguageVersion_, + "setLanguageVersion:"); +_MTL_PRIVATE_DEF_SEL(setLayer_atIndex_, + "setLayer:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setLevel_, + "setLevel:"); +_MTL_PRIVATE_DEF_SEL(setLevelRange_, + "setLevelRange:"); +_MTL_PRIVATE_DEF_SEL(setLibraries_, + "setLibraries:"); +_MTL_PRIVATE_DEF_SEL(setLibrary_, + "setLibrary:"); +_MTL_PRIVATE_DEF_SEL(setLibraryType_, + "setLibraryType:"); +_MTL_PRIVATE_DEF_SEL(setLinkedFunctions_, + "setLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setLoadAction_, + "setLoadAction:"); +_MTL_PRIVATE_DEF_SEL(setLodAverage_, + "setLodAverage:"); +_MTL_PRIVATE_DEF_SEL(setLodBias_, + "setLodBias:"); +_MTL_PRIVATE_DEF_SEL(setLodMaxClamp_, + "setLodMaxClamp:"); +_MTL_PRIVATE_DEF_SEL(setLodMinClamp_, + "setLodMinClamp:"); +_MTL_PRIVATE_DEF_SEL(setLogState_, + "setLogState:"); +_MTL_PRIVATE_DEF_SEL(setLookupArchives_, + "setLookupArchives:"); +_MTL_PRIVATE_DEF_SEL(setMachineLearningFunctionDescriptor_, + "setMachineLearningFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMagFilter_, + "setMagFilter:"); +_MTL_PRIVATE_DEF_SEL(setMathFloatingPointFunctions_, + "setMathFloatingPointFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMathMode_, + "setMathMode:"); +_MTL_PRIVATE_DEF_SEL(setMaxAnisotropy_, + "setMaxAnisotropy:"); +_MTL_PRIVATE_DEF_SEL(setMaxBufferBindCount_, + "setMaxBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxCallStackDepth_, + "setMaxCallStackDepth:"); 
+_MTL_PRIVATE_DEF_SEL(setMaxCommandBufferCount_, + "setMaxCommandBufferCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxCommandsInFlight_, + "setMaxCommandsInFlight:"); +_MTL_PRIVATE_DEF_SEL(setMaxCompatiblePlacementSparsePageSize_, + "setMaxCompatiblePlacementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setMaxFragmentBufferBindCount_, + "setMaxFragmentBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxFragmentCallStackDepth_, + "setMaxFragmentCallStackDepth:"); +_MTL_PRIVATE_DEF_SEL(setMaxInstanceCount_, + "setMaxInstanceCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxKernelBufferBindCount_, + "setMaxKernelBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxKernelThreadgroupMemoryBindCount_, + "setMaxKernelThreadgroupMemoryBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxMeshBufferBindCount_, + "setMaxMeshBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxMotionTransformCount_, + "setMaxMotionTransformCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxObjectBufferBindCount_, + "setMaxObjectBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxObjectThreadgroupMemoryBindCount_, + "setMaxObjectThreadgroupMemoryBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxSamplerStateBindCount_, + "setMaxSamplerStateBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxTessellationFactor_, + "setMaxTessellationFactor:"); +_MTL_PRIVATE_DEF_SEL(setMaxTextureBindCount_, + "setMaxTextureBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadgroupsPerMeshGrid_, + "setMaxTotalThreadgroupsPerMeshGrid:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerMeshThreadgroup_, + "setMaxTotalThreadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerObjectThreadgroup_, + "setMaxTotalThreadsPerObjectThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxTotalThreadsPerThreadgroup_, + "setMaxTotalThreadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexAmplificationCount_, + "setMaxVertexAmplificationCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexBufferBindCount_, + "setMaxVertexBufferBindCount:"); +_MTL_PRIVATE_DEF_SEL(setMaxVertexCallStackDepth_, + 
"setMaxVertexCallStackDepth:"); +_MTL_PRIVATE_DEF_SEL(setMeshAdditionalBinaryFunctions_, + "setMeshAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMeshBuffer_offset_atIndex_, + "setMeshBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshBufferOffset_atIndex_, + "setMeshBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshBuffers_offsets_withRange_, + "setMeshBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshBytes_length_atIndex_, + "setMeshBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshFunction_, + "setMeshFunction:"); +_MTL_PRIVATE_DEF_SEL(setMeshFunctionDescriptor_, + "setMeshFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMeshLinkedFunctions_, + "setMeshLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerState_atIndex_, + "setMeshSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setMeshSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setMeshSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshSamplerStates_withRange_, + "setMeshSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshStaticLinkingDescriptor_, + "setMeshStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setMeshTexture_atIndex_, + "setMeshTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setMeshTextures_withRange_, + "setMeshTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_, + "setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setMinFilter_, + "setMinFilter:"); +_MTL_PRIVATE_DEF_SEL(setMipFilter_, + "setMipFilter:"); +_MTL_PRIVATE_DEF_SEL(setMipmapLevelCount_, + "setMipmapLevelCount:"); +_MTL_PRIVATE_DEF_SEL(setMotionEndBorderMode_, + "setMotionEndBorderMode:"); +_MTL_PRIVATE_DEF_SEL(setMotionEndTime_, + "setMotionEndTime:"); +_MTL_PRIVATE_DEF_SEL(setMotionKeyframeCount_, + "setMotionKeyframeCount:"); 
+_MTL_PRIVATE_DEF_SEL(setMotionStartBorderMode_, + "setMotionStartBorderMode:"); +_MTL_PRIVATE_DEF_SEL(setMotionStartTime_, + "setMotionStartTime:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformBuffer_, + "setMotionTransformBuffer:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformBufferOffset_, + "setMotionTransformBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCount_, + "setMotionTransformCount:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCountBuffer_, + "setMotionTransformCountBuffer:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformCountBufferOffset_, + "setMotionTransformCountBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformStride_, + "setMotionTransformStride:"); +_MTL_PRIVATE_DEF_SEL(setMotionTransformType_, + "setMotionTransformType:"); +_MTL_PRIVATE_DEF_SEL(setMutability_, + "setMutability:"); +_MTL_PRIVATE_DEF_SEL(setName_, + "setName:"); +_MTL_PRIVATE_DEF_SEL(setNodes_, + "setNodes:"); +_MTL_PRIVATE_DEF_SEL(setNormalizedCoordinates_, + "setNormalizedCoordinates:"); +_MTL_PRIVATE_DEF_SEL(setObject_atIndexedSubscript_, + "setObject:atIndexedSubscript:"); +_MTL_PRIVATE_DEF_SEL(setObjectAdditionalBinaryFunctions_, + "setObjectAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setObjectBuffer_offset_atIndex_, + "setObjectBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectBufferOffset_atIndex_, + "setObjectBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectBuffers_offsets_withRange_, + "setObjectBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectBytes_length_atIndex_, + "setObjectBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectFunction_, + "setObjectFunction:"); +_MTL_PRIVATE_DEF_SEL(setObjectFunctionDescriptor_, + "setObjectFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setObjectLinkedFunctions_, + "setObjectLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerState_atIndex_, + "setObjectSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + 
"setObjectSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setObjectSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectSamplerStates_withRange_, + "setObjectSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectStaticLinkingDescriptor_, + "setObjectStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setObjectTexture_atIndex_, + "setObjectTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectTextures_withRange_, + "setObjectTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setObjectThreadgroupMemoryLength_atIndex_, + "setObjectThreadgroupMemoryLength:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_, + "setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setOffset_, + "setOffset:"); +_MTL_PRIVATE_DEF_SEL(setOpaque_, + "setOpaque:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueCurveIntersectionFunctionWithSignature_atIndex_, + "setOpaqueCurveIntersectionFunctionWithSignature:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueCurveIntersectionFunctionWithSignature_withRange_, + "setOpaqueCurveIntersectionFunctionWithSignature:withRange:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_atIndex_, + "setOpaqueTriangleIntersectionFunctionWithSignature:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_withRange_, + "setOpaqueTriangleIntersectionFunctionWithSignature:withRange:"); +_MTL_PRIVATE_DEF_SEL(setOptimizationLevel_, + "setOptimizationLevel:"); +_MTL_PRIVATE_DEF_SEL(setOptions_, + "setOptions:"); +_MTL_PRIVATE_DEF_SEL(setOutputNode_, + "setOutputNode:"); +_MTL_PRIVATE_DEF_SEL(setOutputURL_, + "setOutputURL:"); +_MTL_PRIVATE_DEF_SEL(setOwnerWithIdentity_, + "setOwnerWithIdentity:"); +_MTL_PRIVATE_DEF_SEL(setPayloadMemoryLength_, + "setPayloadMemoryLength:"); +_MTL_PRIVATE_DEF_SEL(setPhysicalIndex_forLogicalIndex_, + 
"setPhysicalIndex:forLogicalIndex:"); +_MTL_PRIVATE_DEF_SEL(setPipelineDataSetSerializer_, + "setPipelineDataSetSerializer:"); +_MTL_PRIVATE_DEF_SEL(setPipelineState_, + "setPipelineState:"); +_MTL_PRIVATE_DEF_SEL(setPixelFormat_, + "setPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setPlacementSparsePageSize_, + "setPlacementSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setPreloadedLibraries_, + "setPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setPreprocessorMacros_, + "setPreprocessorMacros:"); +_MTL_PRIVATE_DEF_SEL(setPreserveInvariance_, + "setPreserveInvariance:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataBuffer_, + "setPrimitiveDataBuffer:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataBufferOffset_, + "setPrimitiveDataBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataElementSize_, + "setPrimitiveDataElementSize:"); +_MTL_PRIVATE_DEF_SEL(setPrimitiveDataStride_, + "setPrimitiveDataStride:"); +_MTL_PRIVATE_DEF_SEL(setPriority_, + "setPriority:"); +_MTL_PRIVATE_DEF_SEL(setPrivateFunctionDescriptors_, + "setPrivateFunctionDescriptors:"); +_MTL_PRIVATE_DEF_SEL(setPrivateFunctions_, + "setPrivateFunctions:"); +_MTL_PRIVATE_DEF_SEL(setPurgeableState_, + "setPurgeableState:"); +_MTL_PRIVATE_DEF_SEL(setRAddressMode_, + "setRAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBuffer_, + "setRadiusBuffer:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBufferOffset_, + "setRadiusBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setRadiusBuffers_, + "setRadiusBuffers:"); +_MTL_PRIVATE_DEF_SEL(setRadiusFormat_, + "setRadiusFormat:"); +_MTL_PRIVATE_DEF_SEL(setRadiusStride_, + "setRadiusStride:"); +_MTL_PRIVATE_DEF_SEL(setRasterSampleCount_, + "setRasterSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setRasterizationEnabled_, + "setRasterizationEnabled:"); +_MTL_PRIVATE_DEF_SEL(setRasterizationRateMap_, + "setRasterizationRateMap:"); +_MTL_PRIVATE_DEF_SEL(setReadMask_, + "setReadMask:"); +_MTL_PRIVATE_DEF_SEL(setReductionMode_, + "setReductionMode:"); +_MTL_PRIVATE_DEF_SEL(setRenderPipelineState_, + "setRenderPipelineState:"); 
+_MTL_PRIVATE_DEF_SEL(setRenderPipelineState_atIndex_, + "setRenderPipelineState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setRenderPipelineStates_withRange_, + "setRenderPipelineStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetArrayLength_, + "setRenderTargetArrayLength:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetHeight_, + "setRenderTargetHeight:"); +_MTL_PRIVATE_DEF_SEL(setRenderTargetWidth_, + "setRenderTargetWidth:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerMeshThreadgroup_, + "setRequiredThreadsPerMeshThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerObjectThreadgroup_, + "setRequiredThreadsPerObjectThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setRequiredThreadsPerThreadgroup_, + "setRequiredThreadsPerThreadgroup:"); +_MTL_PRIVATE_DEF_SEL(setResolveDepthPlane_, + "setResolveDepthPlane:"); +_MTL_PRIVATE_DEF_SEL(setResolveLevel_, + "setResolveLevel:"); +_MTL_PRIVATE_DEF_SEL(setResolveSlice_, + "setResolveSlice:"); +_MTL_PRIVATE_DEF_SEL(setResolveTexture_, + "setResolveTexture:"); +_MTL_PRIVATE_DEF_SEL(setResource_atBufferIndex_, + "setResource:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setResourceOptions_, + "setResourceOptions:"); +_MTL_PRIVATE_DEF_SEL(setResourceViewCount_, + "setResourceViewCount:"); +_MTL_PRIVATE_DEF_SEL(setRetainedReferences_, + "setRetainedReferences:"); +_MTL_PRIVATE_DEF_SEL(setRgbBlendOperation_, + "setRgbBlendOperation:"); +_MTL_PRIVATE_DEF_SEL(setSAddressMode_, + "setSAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setSampleBuffer_, + "setSampleBuffer:"); +_MTL_PRIVATE_DEF_SEL(setSampleCount_, + "setSampleCount:"); +_MTL_PRIVATE_DEF_SEL(setSamplePositions_count_, + "setSamplePositions:count:"); +_MTL_PRIVATE_DEF_SEL(setSamplerState_atIndex_, + "setSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); 
+_MTL_PRIVATE_DEF_SEL(setSamplerStates_withRange_, + "setSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setScissorRect_, + "setScissorRect:"); +_MTL_PRIVATE_DEF_SEL(setScissorRects_count_, + "setScissorRects:count:"); +_MTL_PRIVATE_DEF_SEL(setScratchBufferAllocator_, + "setScratchBufferAllocator:"); +_MTL_PRIVATE_DEF_SEL(setScreenSize_, + "setScreenSize:"); +_MTL_PRIVATE_DEF_SEL(setSegmentControlPointCount_, + "setSegmentControlPointCount:"); +_MTL_PRIVATE_DEF_SEL(setSegmentCount_, + "setSegmentCount:"); +_MTL_PRIVATE_DEF_SEL(setShaderReflection_, + "setShaderReflection:"); +_MTL_PRIVATE_DEF_SEL(setShaderValidation_, + "setShaderValidation:"); +_MTL_PRIVATE_DEF_SEL(setShouldMaximizeConcurrentCompilation_, + "setShouldMaximizeConcurrentCompilation:"); +_MTL_PRIVATE_DEF_SEL(setSignaledValue_, + "setSignaledValue:"); +_MTL_PRIVATE_DEF_SEL(setSize_, + "setSize:"); +_MTL_PRIVATE_DEF_SEL(setSlice_, + "setSlice:"); +_MTL_PRIVATE_DEF_SEL(setSliceRange_, + "setSliceRange:"); +_MTL_PRIVATE_DEF_SEL(setSource_, + "setSource:"); +_MTL_PRIVATE_DEF_SEL(setSourceAlphaBlendFactor_, + "setSourceAlphaBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setSourceRGBBlendFactor_, + "setSourceRGBBlendFactor:"); +_MTL_PRIVATE_DEF_SEL(setSparsePageSize_, + "setSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(setSpecializedName_, + "setSpecializedName:"); +_MTL_PRIVATE_DEF_SEL(setStageInRegion_, + "setStageInRegion:"); +_MTL_PRIVATE_DEF_SEL(setStageInRegionWithIndirectBuffer_indirectBufferOffset_, + "setStageInRegionWithIndirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setStageInputDescriptor_, + "setStageInputDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setStartOfEncoderSampleIndex_, + "setStartOfEncoderSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStartOfFragmentSampleIndex_, + "setStartOfFragmentSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStartOfVertexSampleIndex_, + "setStartOfVertexSampleIndex:"); +_MTL_PRIVATE_DEF_SEL(setStaticLinkingDescriptor_, + "setStaticLinkingDescriptor:"); 
+_MTL_PRIVATE_DEF_SEL(setStencilAttachment_, + "setStencilAttachment:"); +_MTL_PRIVATE_DEF_SEL(setStencilAttachmentPixelFormat_, + "setStencilAttachmentPixelFormat:"); +_MTL_PRIVATE_DEF_SEL(setStencilCompareFunction_, + "setStencilCompareFunction:"); +_MTL_PRIVATE_DEF_SEL(setStencilFailureOperation_, + "setStencilFailureOperation:"); +_MTL_PRIVATE_DEF_SEL(setStencilFrontReferenceValue_backReferenceValue_, + "setStencilFrontReferenceValue:backReferenceValue:"); +_MTL_PRIVATE_DEF_SEL(setStencilReferenceValue_, + "setStencilReferenceValue:"); +_MTL_PRIVATE_DEF_SEL(setStencilResolveFilter_, + "setStencilResolveFilter:"); +_MTL_PRIVATE_DEF_SEL(setStencilStoreAction_, + "setStencilStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setStencilStoreActionOptions_, + "setStencilStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setStepFunction_, + "setStepFunction:"); +_MTL_PRIVATE_DEF_SEL(setStepRate_, + "setStepRate:"); +_MTL_PRIVATE_DEF_SEL(setStorageMode_, + "setStorageMode:"); +_MTL_PRIVATE_DEF_SEL(setStoreAction_, + "setStoreAction:"); +_MTL_PRIVATE_DEF_SEL(setStoreActionOptions_, + "setStoreActionOptions:"); +_MTL_PRIVATE_DEF_SEL(setStride_, + "setStride:"); +_MTL_PRIVATE_DEF_SEL(setStrides_, + "setStrides:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingBinaryFunctions_, + "setSupportAddingBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingFragmentBinaryFunctions_, + "setSupportAddingFragmentBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportAddingVertexBinaryFunctions_, + "setSupportAddingVertexBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setSupportArgumentBuffers_, + "setSupportArgumentBuffers:"); +_MTL_PRIVATE_DEF_SEL(setSupportAttributeStrides_, + "setSupportAttributeStrides:"); +_MTL_PRIVATE_DEF_SEL(setSupportBinaryLinking_, + "setSupportBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportColorAttachmentMapping_, + "setSupportColorAttachmentMapping:"); +_MTL_PRIVATE_DEF_SEL(setSupportDynamicAttributeStride_, + "setSupportDynamicAttributeStride:"); 
+_MTL_PRIVATE_DEF_SEL(setSupportFragmentBinaryLinking_, + "setSupportFragmentBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportIndirectCommandBuffers_, + "setSupportIndirectCommandBuffers:"); +_MTL_PRIVATE_DEF_SEL(setSupportMeshBinaryLinking_, + "setSupportMeshBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportObjectBinaryLinking_, + "setSupportObjectBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSupportRayTracing_, + "setSupportRayTracing:"); +_MTL_PRIVATE_DEF_SEL(setSupportVertexBinaryLinking_, + "setSupportVertexBinaryLinking:"); +_MTL_PRIVATE_DEF_SEL(setSwizzle_, + "setSwizzle:"); +_MTL_PRIVATE_DEF_SEL(setTAddressMode_, + "setTAddressMode:"); +_MTL_PRIVATE_DEF_SEL(setTessellationControlPointIndexType_, + "setTessellationControlPointIndexType:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorBuffer_offset_instanceStride_, + "setTessellationFactorBuffer:offset:instanceStride:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorFormat_, + "setTessellationFactorFormat:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorScale_, + "setTessellationFactorScale:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorScaleEnabled_, + "setTessellationFactorScaleEnabled:"); +_MTL_PRIVATE_DEF_SEL(setTessellationFactorStepFunction_, + "setTessellationFactorStepFunction:"); +_MTL_PRIVATE_DEF_SEL(setTessellationOutputWindingOrder_, + "setTessellationOutputWindingOrder:"); +_MTL_PRIVATE_DEF_SEL(setTessellationPartitionMode_, + "setTessellationPartitionMode:"); +_MTL_PRIVATE_DEF_SEL(setTexture_, + "setTexture:"); +_MTL_PRIVATE_DEF_SEL(setTexture_atIndex_, + "setTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureType_, + "setTextureType:"); +_MTL_PRIVATE_DEF_SEL(setTextureView_atIndex_, + "setTextureView:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureView_descriptor_atIndex_, + "setTextureView:descriptor:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTextureViewFromBuffer_descriptor_offset_bytesPerRow_atIndex_, + "setTextureViewFromBuffer:descriptor:offset:bytesPerRow:atIndex:"); 
+_MTL_PRIVATE_DEF_SEL(setTextures_withRange_, + "setTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_, + "setThreadGroupSizeIsMultipleOfThreadExecutionWidth:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_, + "setThreadgroupMemoryLength:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_atIndex_, + "setThreadgroupMemoryLength:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupMemoryLength_offset_atIndex_, + "setThreadgroupMemoryLength:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setThreadgroupSizeMatchesTileSize_, + "setThreadgroupSizeMatchesTileSize:"); +_MTL_PRIVATE_DEF_SEL(setTileAccelerationStructure_atBufferIndex_, + "setTileAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileAdditionalBinaryFunctions_, + "setTileAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setTileBuffer_offset_atIndex_, + "setTileBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileBufferOffset_atIndex_, + "setTileBufferOffset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileBuffers_offsets_withRange_, + "setTileBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileBytes_length_atIndex_, + "setTileBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileFunction_, + "setTileFunction:"); +_MTL_PRIVATE_DEF_SEL(setTileFunctionDescriptor_, + "setTileFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setTileHeight_, + "setTileHeight:"); +_MTL_PRIVATE_DEF_SEL(setTileIntersectionFunctionTable_atBufferIndex_, + "setTileIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileIntersectionFunctionTables_withBufferRange_, + "setTileIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerState_atIndex_, + "setTileSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setTileSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + 
"setTileSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileSamplerStates_withRange_, + "setTileSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileTexture_atIndex_, + "setTileTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileTextures_withRange_, + "setTileTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setTileVisibleFunctionTable_atBufferIndex_, + "setTileVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setTileVisibleFunctionTables_withBufferRange_, + "setTileVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setTileWidth_, + "setTileWidth:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixBuffer_, + "setTransformationMatrixBuffer:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixBufferOffset_, + "setTransformationMatrixBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setTransformationMatrixLayout_, + "setTransformationMatrixLayout:"); +_MTL_PRIVATE_DEF_SEL(setTriangleCount_, + "setTriangleCount:"); +_MTL_PRIVATE_DEF_SEL(setTriangleFillMode_, + "setTriangleFillMode:"); +_MTL_PRIVATE_DEF_SEL(setType_, + "setType:"); +_MTL_PRIVATE_DEF_SEL(setUrl_, + "setUrl:"); +_MTL_PRIVATE_DEF_SEL(setUsage_, + "setUsage:"); +_MTL_PRIVATE_DEF_SEL(setVertexAccelerationStructure_atBufferIndex_, + "setVertexAccelerationStructure:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexAdditionalBinaryFunctions_, + "setVertexAdditionalBinaryFunctions:"); +_MTL_PRIVATE_DEF_SEL(setVertexAmplificationCount_viewMappings_, + "setVertexAmplificationCount:viewMappings:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_, + "setVertexBuffer:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_offset_atIndex_, + "setVertexBuffer:offset:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffer_offset_attributeStride_atIndex_, + "setVertexBuffer:offset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_, + "setVertexBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_atIndex_, + "setVertexBufferOffset:atIndex:"); 
+_MTL_PRIVATE_DEF_SEL(setVertexBufferOffset_attributeStride_atIndex_, + "setVertexBufferOffset:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_, + "setVertexBuffers:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_offsets_attributeStrides_withRange_, + "setVertexBuffers:offsets:attributeStrides:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexBuffers_offsets_withRange_, + "setVertexBuffers:offsets:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexBytes_length_atIndex_, + "setVertexBytes:length:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexBytes_length_attributeStride_atIndex_, + "setVertexBytes:length:attributeStride:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexDescriptor_, + "setVertexDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexFormat_, + "setVertexFormat:"); +_MTL_PRIVATE_DEF_SEL(setVertexFunction_, + "setVertexFunction:"); +_MTL_PRIVATE_DEF_SEL(setVertexFunctionDescriptor_, + "setVertexFunctionDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexIntersectionFunctionTable_atBufferIndex_, + "setVertexIntersectionFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexIntersectionFunctionTables_withBufferRange_, + "setVertexIntersectionFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexLinkedFunctions_, + "setVertexLinkedFunctions:"); +_MTL_PRIVATE_DEF_SEL(setVertexPreloadedLibraries_, + "setVertexPreloadedLibraries:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerState_atIndex_, + "setVertexSamplerState:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerState_lodMinClamp_lodMaxClamp_atIndex_, + "setVertexSamplerState:lodMinClamp:lodMaxClamp:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerStates_lodMinClamps_lodMaxClamps_withRange_, + "setVertexSamplerStates:lodMinClamps:lodMaxClamps:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexSamplerStates_withRange_, + "setVertexSamplerStates:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexStaticLinkingDescriptor_, + "setVertexStaticLinkingDescriptor:"); +_MTL_PRIVATE_DEF_SEL(setVertexStride_, + "setVertexStride:"); 
+_MTL_PRIVATE_DEF_SEL(setVertexTexture_atIndex_, + "setVertexTexture:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexTextures_withRange_, + "setVertexTextures:withRange:"); +_MTL_PRIVATE_DEF_SEL(setVertexVisibleFunctionTable_atBufferIndex_, + "setVertexVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVertexVisibleFunctionTables_withBufferRange_, + "setVertexVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setViewport_, + "setViewport:"); +_MTL_PRIVATE_DEF_SEL(setViewports_count_, + "setViewports:count:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultBuffer_, + "setVisibilityResultBuffer:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultMode_offset_, + "setVisibilityResultMode:offset:"); +_MTL_PRIVATE_DEF_SEL(setVisibilityResultType_, + "setVisibilityResultType:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTable_atBufferIndex_, + "setVisibleFunctionTable:atBufferIndex:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTable_atIndex_, + "setVisibleFunctionTable:atIndex:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTables_withBufferRange_, + "setVisibleFunctionTables:withBufferRange:"); +_MTL_PRIVATE_DEF_SEL(setVisibleFunctionTables_withRange_, + "setVisibleFunctionTables:withRange:"); +_MTL_PRIVATE_DEF_SEL(setWidth_, + "setWidth:"); +_MTL_PRIVATE_DEF_SEL(setWriteMask_, + "setWriteMask:"); +_MTL_PRIVATE_DEF_SEL(shaderReflection, + "shaderReflection"); +_MTL_PRIVATE_DEF_SEL(shaderValidation, + "shaderValidation"); +_MTL_PRIVATE_DEF_SEL(sharedCaptureManager, + "sharedCaptureManager"); +_MTL_PRIVATE_DEF_SEL(sharedListener, + "sharedListener"); +_MTL_PRIVATE_DEF_SEL(shouldMaximizeConcurrentCompilation, + "shouldMaximizeConcurrentCompilation"); +_MTL_PRIVATE_DEF_SEL(signalDrawable_, + "signalDrawable:"); +_MTL_PRIVATE_DEF_SEL(signalEvent_value_, + "signalEvent:value:"); +_MTL_PRIVATE_DEF_SEL(signaledValue, + "signaledValue"); +_MTL_PRIVATE_DEF_SEL(size, + "size"); +_MTL_PRIVATE_DEF_SEL(sizeOfCounterHeapEntry_, + "sizeOfCounterHeapEntry:"); +_MTL_PRIVATE_DEF_SEL(slice, 
+ "slice"); +_MTL_PRIVATE_DEF_SEL(sliceRange, + "sliceRange"); +_MTL_PRIVATE_DEF_SEL(source, + "source"); +_MTL_PRIVATE_DEF_SEL(sourceAlphaBlendFactor, + "sourceAlphaBlendFactor"); +_MTL_PRIVATE_DEF_SEL(sourceRGBBlendFactor, + "sourceRGBBlendFactor"); +_MTL_PRIVATE_DEF_SEL(sparseBufferTier, + "sparseBufferTier"); +_MTL_PRIVATE_DEF_SEL(sparsePageSize, + "sparsePageSize"); +_MTL_PRIVATE_DEF_SEL(sparseTextureTier, + "sparseTextureTier"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeInBytes, + "sparseTileSizeInBytes"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeInBytesForSparsePageSize_, + "sparseTileSizeInBytesForSparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_, + "sparseTileSizeWithTextureType:pixelFormat:sampleCount:"); +_MTL_PRIVATE_DEF_SEL(sparseTileSizeWithTextureType_pixelFormat_sampleCount_sparsePageSize_, + "sparseTileSizeWithTextureType:pixelFormat:sampleCount:sparsePageSize:"); +_MTL_PRIVATE_DEF_SEL(specializedName, + "specializedName"); +_MTL_PRIVATE_DEF_SEL(stageInputAttributes, + "stageInputAttributes"); +_MTL_PRIVATE_DEF_SEL(stageInputDescriptor, + "stageInputDescriptor"); +_MTL_PRIVATE_DEF_SEL(stageInputOutputDescriptor, + "stageInputOutputDescriptor"); +_MTL_PRIVATE_DEF_SEL(stages, + "stages"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithCommandQueue_, + "startCaptureWithCommandQueue:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithDescriptor_error_, + "startCaptureWithDescriptor:error:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithDevice_, + "startCaptureWithDevice:"); +_MTL_PRIVATE_DEF_SEL(startCaptureWithScope_, + "startCaptureWithScope:"); +_MTL_PRIVATE_DEF_SEL(startOfEncoderSampleIndex, + "startOfEncoderSampleIndex"); +_MTL_PRIVATE_DEF_SEL(startOfFragmentSampleIndex, + "startOfFragmentSampleIndex"); +_MTL_PRIVATE_DEF_SEL(startOfVertexSampleIndex, + "startOfVertexSampleIndex"); +_MTL_PRIVATE_DEF_SEL(staticLinkingDescriptor, + "staticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(staticThreadgroupMemoryLength, + 
"staticThreadgroupMemoryLength"); +_MTL_PRIVATE_DEF_SEL(status, + "status"); +_MTL_PRIVATE_DEF_SEL(stencilAttachment, + "stencilAttachment"); +_MTL_PRIVATE_DEF_SEL(stencilAttachmentPixelFormat, + "stencilAttachmentPixelFormat"); +_MTL_PRIVATE_DEF_SEL(stencilCompareFunction, + "stencilCompareFunction"); +_MTL_PRIVATE_DEF_SEL(stencilFailureOperation, + "stencilFailureOperation"); +_MTL_PRIVATE_DEF_SEL(stencilResolveFilter, + "stencilResolveFilter"); +_MTL_PRIVATE_DEF_SEL(stepFunction, + "stepFunction"); +_MTL_PRIVATE_DEF_SEL(stepRate, + "stepRate"); +_MTL_PRIVATE_DEF_SEL(stopCapture, + "stopCapture"); +_MTL_PRIVATE_DEF_SEL(storageMode, + "storageMode"); +_MTL_PRIVATE_DEF_SEL(storeAction, + "storeAction"); +_MTL_PRIVATE_DEF_SEL(storeActionOptions, + "storeActionOptions"); +_MTL_PRIVATE_DEF_SEL(stride, + "stride"); +_MTL_PRIVATE_DEF_SEL(strides, + "strides"); +_MTL_PRIVATE_DEF_SEL(structType, + "structType"); +_MTL_PRIVATE_DEF_SEL(supportAddingBinaryFunctions, + "supportAddingBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportAddingFragmentBinaryFunctions, + "supportAddingFragmentBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportAddingVertexBinaryFunctions, + "supportAddingVertexBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(supportArgumentBuffers, + "supportArgumentBuffers"); +_MTL_PRIVATE_DEF_SEL(supportAttributeStrides, + "supportAttributeStrides"); +_MTL_PRIVATE_DEF_SEL(supportBinaryLinking, + "supportBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportColorAttachmentMapping, + "supportColorAttachmentMapping"); +_MTL_PRIVATE_DEF_SEL(supportDynamicAttributeStride, + "supportDynamicAttributeStride"); +_MTL_PRIVATE_DEF_SEL(supportFragmentBinaryLinking, + "supportFragmentBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportIndirectCommandBuffers, + "supportIndirectCommandBuffers"); +_MTL_PRIVATE_DEF_SEL(supportMeshBinaryLinking, + "supportMeshBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportObjectBinaryLinking, + "supportObjectBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supportRayTracing, + 
"supportRayTracing"); +_MTL_PRIVATE_DEF_SEL(supportVertexBinaryLinking, + "supportVertexBinaryLinking"); +_MTL_PRIVATE_DEF_SEL(supports32BitFloatFiltering, + "supports32BitFloatFiltering"); +_MTL_PRIVATE_DEF_SEL(supports32BitMSAA, + "supports32BitMSAA"); +_MTL_PRIVATE_DEF_SEL(supportsBCTextureCompression, + "supportsBCTextureCompression"); +_MTL_PRIVATE_DEF_SEL(supportsCounterSampling_, + "supportsCounterSampling:"); +_MTL_PRIVATE_DEF_SEL(supportsDestination_, + "supportsDestination:"); +_MTL_PRIVATE_DEF_SEL(supportsDynamicLibraries, + "supportsDynamicLibraries"); +_MTL_PRIVATE_DEF_SEL(supportsFamily_, + "supportsFamily:"); +_MTL_PRIVATE_DEF_SEL(supportsFeatureSet_, + "supportsFeatureSet:"); +_MTL_PRIVATE_DEF_SEL(supportsFunctionPointers, + "supportsFunctionPointers"); +_MTL_PRIVATE_DEF_SEL(supportsFunctionPointersFromRender, + "supportsFunctionPointersFromRender"); +_MTL_PRIVATE_DEF_SEL(supportsPrimitiveMotionBlur, + "supportsPrimitiveMotionBlur"); +_MTL_PRIVATE_DEF_SEL(supportsPullModelInterpolation, + "supportsPullModelInterpolation"); +_MTL_PRIVATE_DEF_SEL(supportsQueryTextureLOD, + "supportsQueryTextureLOD"); +_MTL_PRIVATE_DEF_SEL(supportsRasterizationRateMapWithLayerCount_, + "supportsRasterizationRateMapWithLayerCount:"); +_MTL_PRIVATE_DEF_SEL(supportsRaytracing, + "supportsRaytracing"); +_MTL_PRIVATE_DEF_SEL(supportsRaytracingFromRender, + "supportsRaytracingFromRender"); +_MTL_PRIVATE_DEF_SEL(supportsRenderDynamicLibraries, + "supportsRenderDynamicLibraries"); +_MTL_PRIVATE_DEF_SEL(supportsShaderBarycentricCoordinates, + "supportsShaderBarycentricCoordinates"); +_MTL_PRIVATE_DEF_SEL(supportsTextureSampleCount_, + "supportsTextureSampleCount:"); +_MTL_PRIVATE_DEF_SEL(supportsVertexAmplificationCount_, + "supportsVertexAmplificationCount:"); +_MTL_PRIVATE_DEF_SEL(swizzle, + "swizzle"); +_MTL_PRIVATE_DEF_SEL(synchronizeResource_, + "synchronizeResource:"); +_MTL_PRIVATE_DEF_SEL(synchronizeTexture_slice_level_, + "synchronizeTexture:slice:level:"); 
+_MTL_PRIVATE_DEF_SEL(tAddressMode, + "tAddressMode"); +_MTL_PRIVATE_DEF_SEL(tailSizeInBytes, + "tailSizeInBytes"); +_MTL_PRIVATE_DEF_SEL(tensorDataType, + "tensorDataType"); +_MTL_PRIVATE_DEF_SEL(tensorReferenceType, + "tensorReferenceType"); +_MTL_PRIVATE_DEF_SEL(tensorSizeAndAlignWithDescriptor_, + "tensorSizeAndAlignWithDescriptor:"); +_MTL_PRIVATE_DEF_SEL(tessellationControlPointIndexType, + "tessellationControlPointIndexType"); +_MTL_PRIVATE_DEF_SEL(tessellationFactorFormat, + "tessellationFactorFormat"); +_MTL_PRIVATE_DEF_SEL(tessellationFactorStepFunction, + "tessellationFactorStepFunction"); +_MTL_PRIVATE_DEF_SEL(tessellationOutputWindingOrder, + "tessellationOutputWindingOrder"); +_MTL_PRIVATE_DEF_SEL(tessellationPartitionMode, + "tessellationPartitionMode"); +_MTL_PRIVATE_DEF_SEL(texture, + "texture"); +_MTL_PRIVATE_DEF_SEL(texture2DDescriptorWithPixelFormat_width_height_mipmapped_, + "texture2DDescriptorWithPixelFormat:width:height:mipmapped:"); +_MTL_PRIVATE_DEF_SEL(textureBarrier, + "textureBarrier"); +_MTL_PRIVATE_DEF_SEL(textureBufferDescriptorWithPixelFormat_width_resourceOptions_usage_, + "textureBufferDescriptorWithPixelFormat:width:resourceOptions:usage:"); +_MTL_PRIVATE_DEF_SEL(textureCubeDescriptorWithPixelFormat_size_mipmapped_, + "textureCubeDescriptorWithPixelFormat:size:mipmapped:"); +_MTL_PRIVATE_DEF_SEL(textureDataType, + "textureDataType"); +_MTL_PRIVATE_DEF_SEL(textureReferenceType, + "textureReferenceType"); +_MTL_PRIVATE_DEF_SEL(textureType, + "textureType"); +_MTL_PRIVATE_DEF_SEL(threadExecutionWidth, + "threadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth, + "threadGroupSizeIsMultipleOfThreadExecutionWidth"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryAlignment, + "threadgroupMemoryAlignment"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryDataSize, + "threadgroupMemoryDataSize"); +_MTL_PRIVATE_DEF_SEL(threadgroupMemoryLength, + "threadgroupMemoryLength"); 
+_MTL_PRIVATE_DEF_SEL(threadgroupSizeMatchesTileSize, + "threadgroupSizeMatchesTileSize"); +_MTL_PRIVATE_DEF_SEL(tileAdditionalBinaryFunctions, + "tileAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(tileArguments, + "tileArguments"); +_MTL_PRIVATE_DEF_SEL(tileBindings, + "tileBindings"); +_MTL_PRIVATE_DEF_SEL(tileBuffers, + "tileBuffers"); +_MTL_PRIVATE_DEF_SEL(tileFunction, + "tileFunction"); +_MTL_PRIVATE_DEF_SEL(tileFunctionDescriptor, + "tileFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(tileHeight, + "tileHeight"); +_MTL_PRIVATE_DEF_SEL(tileLinkingDescriptor, + "tileLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(tileWidth, + "tileWidth"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixBuffer, + "transformationMatrixBuffer"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixBufferOffset, + "transformationMatrixBufferOffset"); +_MTL_PRIVATE_DEF_SEL(transformationMatrixLayout, + "transformationMatrixLayout"); +_MTL_PRIVATE_DEF_SEL(triangleCount, + "triangleCount"); +_MTL_PRIVATE_DEF_SEL(tryCancel, + "tryCancel"); +_MTL_PRIVATE_DEF_SEL(type, + "type"); +_MTL_PRIVATE_DEF_SEL(updateBufferMappings_heap_operations_count_, + "updateBufferMappings:heap:operations:count:"); +_MTL_PRIVATE_DEF_SEL(updateFence_, + "updateFence:"); +_MTL_PRIVATE_DEF_SEL(updateFence_afterEncoderStages_, + "updateFence:afterEncoderStages:"); +_MTL_PRIVATE_DEF_SEL(updateFence_afterStages_, + "updateFence:afterStages:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMapping_mode_indirectBuffer_indirectBufferOffset_, + "updateTextureMapping:mode:indirectBuffer:indirectBufferOffset:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMapping_mode_region_mipLevel_slice_, + "updateTextureMapping:mode:region:mipLevel:slice:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMappings_heap_operations_count_, + "updateTextureMappings:heap:operations:count:"); +_MTL_PRIVATE_DEF_SEL(updateTextureMappings_mode_regions_mipLevels_slices_numRegions_, + "updateTextureMappings:mode:regions:mipLevels:slices:numRegions:"); +_MTL_PRIVATE_DEF_SEL(url, + "url"); 
+_MTL_PRIVATE_DEF_SEL(usage, + "usage"); +_MTL_PRIVATE_DEF_SEL(useHeap_, + "useHeap:"); +_MTL_PRIVATE_DEF_SEL(useHeap_stages_, + "useHeap:stages:"); +_MTL_PRIVATE_DEF_SEL(useHeaps_count_, + "useHeaps:count:"); +_MTL_PRIVATE_DEF_SEL(useHeaps_count_stages_, + "useHeaps:count:stages:"); +_MTL_PRIVATE_DEF_SEL(useResidencySet_, + "useResidencySet:"); +_MTL_PRIVATE_DEF_SEL(useResidencySets_count_, + "useResidencySets:count:"); +_MTL_PRIVATE_DEF_SEL(useResource_usage_, + "useResource:usage:"); +_MTL_PRIVATE_DEF_SEL(useResource_usage_stages_, + "useResource:usage:stages:"); +_MTL_PRIVATE_DEF_SEL(useResources_count_usage_, + "useResources:count:usage:"); +_MTL_PRIVATE_DEF_SEL(useResources_count_usage_stages_, + "useResources:count:usage:stages:"); +_MTL_PRIVATE_DEF_SEL(usedSize, + "usedSize"); +_MTL_PRIVATE_DEF_SEL(vertexAdditionalBinaryFunctions, + "vertexAdditionalBinaryFunctions"); +_MTL_PRIVATE_DEF_SEL(vertexArguments, + "vertexArguments"); +_MTL_PRIVATE_DEF_SEL(vertexAttributes, + "vertexAttributes"); +_MTL_PRIVATE_DEF_SEL(vertexBindings, + "vertexBindings"); +_MTL_PRIVATE_DEF_SEL(vertexBuffer, + "vertexBuffer"); +_MTL_PRIVATE_DEF_SEL(vertexBufferOffset, + "vertexBufferOffset"); +_MTL_PRIVATE_DEF_SEL(vertexBuffers, + "vertexBuffers"); +_MTL_PRIVATE_DEF_SEL(vertexDescriptor, + "vertexDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexFormat, + "vertexFormat"); +_MTL_PRIVATE_DEF_SEL(vertexFunction, + "vertexFunction"); +_MTL_PRIVATE_DEF_SEL(vertexFunctionDescriptor, + "vertexFunctionDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexLinkedFunctions, + "vertexLinkedFunctions"); +_MTL_PRIVATE_DEF_SEL(vertexLinkingDescriptor, + "vertexLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexPreloadedLibraries, + "vertexPreloadedLibraries"); +_MTL_PRIVATE_DEF_SEL(vertexStaticLinkingDescriptor, + "vertexStaticLinkingDescriptor"); +_MTL_PRIVATE_DEF_SEL(vertexStride, + "vertexStride"); +_MTL_PRIVATE_DEF_SEL(vertical, + "vertical"); +_MTL_PRIVATE_DEF_SEL(verticalSampleStorage, + "verticalSampleStorage"); 
+_MTL_PRIVATE_DEF_SEL(visibilityResultBuffer, + "visibilityResultBuffer"); +_MTL_PRIVATE_DEF_SEL(visibilityResultType, + "visibilityResultType"); +_MTL_PRIVATE_DEF_SEL(visibleFunctionTableDescriptor, + "visibleFunctionTableDescriptor"); +_MTL_PRIVATE_DEF_SEL(waitForDrawable_, + "waitForDrawable:"); +_MTL_PRIVATE_DEF_SEL(waitForEvent_value_, + "waitForEvent:value:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_, + "waitForFence:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_beforeEncoderStages_, + "waitForFence:beforeEncoderStages:"); +_MTL_PRIVATE_DEF_SEL(waitForFence_beforeStages_, + "waitForFence:beforeStages:"); +_MTL_PRIVATE_DEF_SEL(waitUntilCompleted, + "waitUntilCompleted"); +_MTL_PRIVATE_DEF_SEL(waitUntilScheduled, + "waitUntilScheduled"); +_MTL_PRIVATE_DEF_SEL(waitUntilSignaledValue_timeoutMS_, + "waitUntilSignaledValue:timeoutMS:"); +_MTL_PRIVATE_DEF_SEL(width, + "width"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_, + "writeCompactedAccelerationStructureSize:toBuffer:"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_, + "writeCompactedAccelerationStructureSize:toBuffer:offset:"); +_MTL_PRIVATE_DEF_SEL(writeCompactedAccelerationStructureSize_toBuffer_offset_sizeDataType_, + "writeCompactedAccelerationStructureSize:toBuffer:offset:sizeDataType:"); +_MTL_PRIVATE_DEF_SEL(writeMask, + "writeMask"); +_MTL_PRIVATE_DEF_SEL(writeTimestampIntoHeap_atIndex_, + "writeTimestampIntoHeap:atIndex:"); +_MTL_PRIVATE_DEF_SEL(writeTimestampWithGranularity_afterStage_intoHeap_atIndex_, + "writeTimestampWithGranularity:afterStage:intoHeap:atIndex:"); +_MTL_PRIVATE_DEF_SEL(writeTimestampWithGranularity_intoHeap_atIndex_, + "writeTimestampWithGranularity:intoHeap:atIndex:"); + +} diff --git a/dist/include/metal_cpp/Metal/MTLHeap.hpp b/dist/include/metal_cpp/Metal/MTLHeap.hpp new file mode 100644 index 0000000..251b284 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLHeap.hpp @@ -0,0 +1,318 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLHeap.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" + +namespace MTL +{ +class AccelerationStructure; +class AccelerationStructureDescriptor; +class Buffer; +class Device; +class HeapDescriptor; +class Texture; +class TextureDescriptor; +_MTL_ENUM(NS::Integer, HeapType) { + HeapTypeAutomatic = 0, + HeapTypePlacement = 1, + HeapTypeSparse = 2, +}; + +class HeapDescriptor : public NS::Copying +{ +public: + static HeapDescriptor* alloc(); + + CPUCacheMode cpuCacheMode() const; + + HazardTrackingMode hazardTrackingMode() const; + + HeapDescriptor* init(); + + SparsePageSize maxCompatiblePlacementSparsePageSize() const; + + ResourceOptions resourceOptions() const; + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void setMaxCompatiblePlacementSparsePageSize(MTL::SparsePageSize 
maxCompatiblePlacementSparsePageSize); + + void setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setSize(NS::UInteger size); + + void setSparsePageSize(MTL::SparsePageSize sparsePageSize); + + void setStorageMode(MTL::StorageMode storageMode); + + void setType(MTL::HeapType type); + + NS::UInteger size() const; + SparsePageSize sparsePageSize() const; + + StorageMode storageMode() const; + + HeapType type() const; +}; +class Heap : public NS::Referencing +{ +public: + CPUCacheMode cpuCacheMode() const; + + NS::UInteger currentAllocatedSize() const; + + Device* device() const; + + HazardTrackingMode hazardTrackingMode() const; + + NS::String* label() const; + + NS::UInteger maxAvailableSize(NS::UInteger alignment); + + AccelerationStructure* newAccelerationStructure(NS::UInteger size); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor); + AccelerationStructure* newAccelerationStructure(NS::UInteger size, NS::UInteger offset); + AccelerationStructure* newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor, NS::UInteger offset); + + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options); + Buffer* newBuffer(NS::UInteger length, MTL::ResourceOptions options, NS::UInteger offset); + + Texture* newTexture(const MTL::TextureDescriptor* descriptor); + Texture* newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset); + + ResourceOptions resourceOptions() const; + + void setLabel(const NS::String* label); + + PurgeableState setPurgeableState(MTL::PurgeableState state); + + NS::UInteger size() const; + + StorageMode storageMode() const; + + HeapType type() const; + + NS::UInteger usedSize() const; +}; + +} +_MTL_INLINE MTL::HeapDescriptor* MTL::HeapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLHeapDescriptor)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::HeapDescriptor::cpuCacheMode() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::HeapDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::HeapDescriptor* MTL::HeapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::SparsePageSize MTL::HeapDescriptor::maxCompatiblePlacementSparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCompatiblePlacementSparsePageSize)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::HeapDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::HeapDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::HeapDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void MTL::HeapDescriptor::setMaxCompatiblePlacementSparsePageSize(MTL::SparsePageSize maxCompatiblePlacementSparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCompatiblePlacementSparsePageSize_), maxCompatiblePlacementSparsePageSize); +} + +_MTL_INLINE void MTL::HeapDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::HeapDescriptor::setSize(NS::UInteger size) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSize_), size); +} + +_MTL_INLINE void MTL::HeapDescriptor::setSparsePageSize(MTL::SparsePageSize sparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSparsePageSize_), sparsePageSize); +} + +_MTL_INLINE void MTL::HeapDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + 
+_MTL_INLINE void MTL::HeapDescriptor::setType(MTL::HeapType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE NS::UInteger MTL::HeapDescriptor::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} + +_MTL_INLINE MTL::SparsePageSize MTL::HeapDescriptor::sparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparsePageSize)); +} + +_MTL_INLINE MTL::StorageMode MTL::HeapDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::HeapType MTL::HeapDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::Heap::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::currentAllocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(currentAllocatedSize)); +} + +_MTL_INLINE MTL::Device* MTL::Heap::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::Heap::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE NS::String* MTL::Heap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::maxAvailableSize(NS::UInteger alignment) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxAvailableSizeWithAlignment_), alignment); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(NS::UInteger size) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_), size); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_), 
descriptor); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(NS::UInteger size, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithSize_offset_), size, offset); +} + +_MTL_INLINE MTL::AccelerationStructure* MTL::Heap::newAccelerationStructure(const MTL::AccelerationStructureDescriptor* descriptor, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newAccelerationStructureWithDescriptor_offset_), descriptor, offset); +} + +_MTL_INLINE MTL::Buffer* MTL::Heap::newBuffer(NS::UInteger length, MTL::ResourceOptions options) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_), length, options); +} + +_MTL_INLINE MTL::Buffer* MTL::Heap::newBuffer(NS::UInteger length, MTL::ResourceOptions options, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newBufferWithLength_options_offset_), length, options, offset); +} + +_MTL_INLINE MTL::Texture* MTL::Heap::newTexture(const MTL::TextureDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Heap::newTexture(const MTL::TextureDescriptor* descriptor, NS::UInteger offset) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureWithDescriptor_offset_), descriptor, offset); +} + +_MTL_INLINE MTL::ResourceOptions MTL::Heap::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::Heap::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::PurgeableState MTL::Heap::setPurgeableState(MTL::PurgeableState state) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setPurgeableState_), state); +} + +_MTL_INLINE NS::UInteger MTL::Heap::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} + +_MTL_INLINE 
MTL::StorageMode MTL::Heap::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::HeapType MTL::Heap::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::UInteger MTL::Heap::usedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usedSize)); +} diff --git a/dist/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp b/dist/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp new file mode 100644 index 0000000..9402318 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp @@ -0,0 +1,182 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class IOCommandBuffer; +class IOFileHandle; +class SharedEvent; +class Texture; +_MTL_ENUM(NS::Integer, IOStatus) { + IOStatusPending = 0, + IOStatusCancelled = 1, + IOStatusError = 2, + IOStatusComplete = 3, +}; + +using IOCommandBufferHandler = void (^)(MTL::IOCommandBuffer*); +using IOCommandBufferHandlerFunction = std::function; + +class IOCommandBuffer : public NS::Referencing +{ +public: + void addBarrier(); + + void addCompletedHandler(const MTL::IOCommandBufferHandler block); + void addCompletedHandler(const MTL::IOCommandBufferHandlerFunction& function); + + void commit(); + + void copyStatusToBuffer(const MTL::Buffer* buffer, NS::UInteger offset); + + void enqueue(); + + NS::Error* error() const; + + NS::String* label() const; + + void loadBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void loadBytes(const void* pointer, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void loadTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level, MTL::Size size, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Origin destinationOrigin, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset); + + void popDebugGroup(); + + void pushDebugGroup(const NS::String* string); + + void setLabel(const NS::String* label); + + void signalEvent(const MTL::SharedEvent* event, uint64_t value); + + IOStatus status() const; + + void tryCancel(); + + void wait(const MTL::SharedEvent* event, uint64_t value); 
+ void waitUntilCompleted(); +}; + +} +_MTL_INLINE void MTL::IOCommandBuffer::addBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addBarrier)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::addCompletedHandler(const MTL::IOCommandBufferHandler block) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addCompletedHandler_), block); +} + +_MTL_INLINE void MTL::IOCommandBuffer::addCompletedHandler(const MTL::IOCommandBufferHandlerFunction& function) +{ + __block MTL::IOCommandBufferHandlerFunction blockFunction = function; + addCompletedHandler(^(MTL::IOCommandBuffer* pCommandBuffer) { blockFunction(pCommandBuffer); }); +} + +_MTL_INLINE void MTL::IOCommandBuffer::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::copyStatusToBuffer(const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyStatusToBuffer_offset_), buffer, offset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::enqueue() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueue)); +} + +_MTL_INLINE NS::Error* MTL::IOCommandBuffer::error() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(error)); +} + +_MTL_INLINE NS::String* MTL::IOCommandBuffer::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::loadBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadBuffer_offset_size_sourceHandle_sourceHandleOffset_), buffer, offset, size, sourceHandle, sourceHandleOffset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::loadBytes(const void* pointer, NS::UInteger size, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadBytes_size_sourceHandle_sourceHandleOffset_), pointer, size, sourceHandle, sourceHandleOffset); +} + 
+_MTL_INLINE void MTL::IOCommandBuffer::loadTexture(const MTL::Texture* texture, NS::UInteger slice, NS::UInteger level, MTL::Size size, NS::UInteger sourceBytesPerRow, NS::UInteger sourceBytesPerImage, MTL::Origin destinationOrigin, const MTL::IOFileHandle* sourceHandle, NS::UInteger sourceHandleOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(loadTexture_slice_level_size_sourceBytesPerRow_sourceBytesPerImage_destinationOrigin_sourceHandle_sourceHandleOffset_), texture, slice, level, size, sourceBytesPerRow, sourceBytesPerImage, destinationOrigin, sourceHandle, sourceHandleOffset); +} + +_MTL_INLINE void MTL::IOCommandBuffer::popDebugGroup() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(popDebugGroup)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::pushDebugGroup(const NS::String* string) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(pushDebugGroup_), string); +} + +_MTL_INLINE void MTL::IOCommandBuffer::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::IOCommandBuffer::signalEvent(const MTL::SharedEvent* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(signalEvent_value_), event, value); +} + +_MTL_INLINE MTL::IOStatus MTL::IOCommandBuffer::status() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(status)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::tryCancel() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(tryCancel)); +} + +_MTL_INLINE void MTL::IOCommandBuffer::wait(const MTL::SharedEvent* event, uint64_t value) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForEvent_value_), event, value); +} + +_MTL_INLINE void MTL::IOCommandBuffer::waitUntilCompleted() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitUntilCompleted)); +} diff --git a/dist/include/metal_cpp/Metal/MTLIOCommandQueue.hpp b/dist/include/metal_cpp/Metal/MTLIOCommandQueue.hpp new file mode 100644 index 0000000..ac956c8 --- /dev/null +++ 
b/dist/include/metal_cpp/Metal/MTLIOCommandQueue.hpp @@ -0,0 +1,211 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCommandQueue.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class Buffer; +class IOCommandBuffer; +class IOCommandQueueDescriptor; +class IOScratchBuffer; +class IOScratchBufferAllocator; +_MTL_ENUM(NS::Integer, IOPriority) { + IOPriorityHigh = 0, + IOPriorityNormal = 1, + IOPriorityLow = 2, +}; + +_MTL_ENUM(NS::Integer, IOCommandQueueType) { + IOCommandQueueTypeConcurrent = 0, + IOCommandQueueTypeSerial = 1, +}; + +_MTL_ENUM(NS::Integer, IOError) { + IOErrorURLInvalid = 1, + IOErrorInternal = 2, +}; + +_MTL_CONST(NS::ErrorDomain, IOErrorDomain); +class IOCommandQueue : public NS::Referencing +{ +public: + IOCommandBuffer* commandBuffer(); + IOCommandBuffer* commandBufferWithUnretainedReferences(); + + void enqueueBarrier(); + + NS::String* label() const; + void setLabel(const NS::String* label); +}; +class IOScratchBuffer : public 
NS::Referencing +{ +public: + Buffer* buffer() const; +}; +class IOScratchBufferAllocator : public NS::Referencing +{ +public: + IOScratchBuffer* newScratchBuffer(NS::UInteger minimumSize); +}; +class IOCommandQueueDescriptor : public NS::Copying +{ +public: + static IOCommandQueueDescriptor* alloc(); + + IOCommandQueueDescriptor* init(); + + NS::UInteger maxCommandBufferCount() const; + + NS::UInteger maxCommandsInFlight() const; + + IOPriority priority() const; + + IOScratchBufferAllocator* scratchBufferAllocator() const; + + void setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount); + + void setMaxCommandsInFlight(NS::UInteger maxCommandsInFlight); + + void setPriority(MTL::IOPriority priority); + + void setScratchBufferAllocator(const MTL::IOScratchBufferAllocator* scratchBufferAllocator); + + void setType(MTL::IOCommandQueueType type); + IOCommandQueueType type() const; +}; +class IOFileHandle : public NS::Referencing +{ +public: + NS::String* label() const; + void setLabel(const NS::String* label); +}; + +} +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, IOErrorDomain); +_MTL_INLINE MTL::IOCommandBuffer* MTL::IOCommandQueue::commandBuffer() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBuffer)); +} + +_MTL_INLINE MTL::IOCommandBuffer* MTL::IOCommandQueue::commandBufferWithUnretainedReferences() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandBufferWithUnretainedReferences)); +} + +_MTL_INLINE void MTL::IOCommandQueue::enqueueBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(enqueueBarrier)); +} + +_MTL_INLINE NS::String* MTL::IOCommandQueue::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOCommandQueue::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::Buffer* MTL::IOScratchBuffer::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE 
MTL::IOScratchBuffer* MTL::IOScratchBufferAllocator::newScratchBuffer(NS::UInteger minimumSize) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newScratchBufferWithMinimumSize_), minimumSize); +} + +_MTL_INLINE MTL::IOCommandQueueDescriptor* MTL::IOCommandQueueDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIOCommandQueueDescriptor)); +} + +_MTL_INLINE MTL::IOCommandQueueDescriptor* MTL::IOCommandQueueDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::IOCommandQueueDescriptor::maxCommandBufferCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandBufferCount)); +} + +_MTL_INLINE NS::UInteger MTL::IOCommandQueueDescriptor::maxCommandsInFlight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCommandsInFlight)); +} + +_MTL_INLINE MTL::IOPriority MTL::IOCommandQueueDescriptor::priority() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(priority)); +} + +_MTL_INLINE MTL::IOScratchBufferAllocator* MTL::IOCommandQueueDescriptor::scratchBufferAllocator() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(scratchBufferAllocator)); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setMaxCommandBufferCount(NS::UInteger maxCommandBufferCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandBufferCount_), maxCommandBufferCount); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setMaxCommandsInFlight(NS::UInteger maxCommandsInFlight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCommandsInFlight_), maxCommandsInFlight); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setPriority(MTL::IOPriority priority) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPriority_), priority); +} + +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setScratchBufferAllocator(const MTL::IOScratchBufferAllocator* scratchBufferAllocator) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScratchBufferAllocator_), scratchBufferAllocator); +} 
+ +_MTL_INLINE void MTL::IOCommandQueueDescriptor::setType(MTL::IOCommandQueueType type) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setType_), type); +} + +_MTL_INLINE MTL::IOCommandQueueType MTL::IOCommandQueueDescriptor::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE NS::String* MTL::IOFileHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::IOFileHandle::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} diff --git a/dist/include/metal_cpp/Metal/MTLIOCompressor.hpp b/dist/include/metal_cpp/Metal/MTLIOCompressor.hpp new file mode 100644 index 0000000..920fa61 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLIOCompressor.hpp @@ -0,0 +1,94 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIOCompressor.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLDevice.hpp" + +#include "../Foundation/Foundation.hpp" + +namespace MTL +{ +using IOCompressionContext=void*; + +_MTL_ENUM(NS::Integer, IOCompressionStatus) { + IOCompressionStatusComplete = 0, + IOCompressionStatusError = 1, +}; + +size_t IOCompressionContextDefaultChunkSize(); + +IOCompressionContext IOCreateCompressionContext(const char* path, IOCompressionMethod type, size_t chunkSize); + +void IOCompressionContextAppendData(IOCompressionContext context, const void* data, size_t size); + +IOCompressionStatus IOFlushAndDestroyCompressionContext(IOCompressionContext context); + +} + +#if defined(MTL_PRIVATE_IMPLEMENTATION) + +namespace MTL::Private { + +MTL_DEF_FUNC(MTLIOCompressionContextDefaultChunkSize, size_t (*)(void)); + +MTL_DEF_FUNC( MTLIOCreateCompressionContext, void* (*)(const char*, MTL::IOCompressionMethod, size_t) ); + +MTL_DEF_FUNC( MTLIOCompressionContextAppendData, void (*)(void*, const void*, size_t) ); + +MTL_DEF_FUNC( MTLIOFlushAndDestroyCompressionContext, MTL::IOCompressionStatus (*)(void*) ); + +} + +_NS_EXPORT size_t MTL::IOCompressionContextDefaultChunkSize() +{ + return MTL::Private::MTLIOCompressionContextDefaultChunkSize(); +} + +_NS_EXPORT void* MTL::IOCreateCompressionContext(const char* path, IOCompressionMethod type, size_t chunkSize) +{ + if ( MTL::Private::MTLIOCreateCompressionContext ) + { + return MTL::Private::MTLIOCreateCompressionContext( path, type, chunkSize ); + } + return nullptr; +} + +_NS_EXPORT void MTL::IOCompressionContextAppendData(void* context, const void* data, size_t size) +{ + if ( MTL::Private::MTLIOCompressionContextAppendData ) + { + MTL::Private::MTLIOCompressionContextAppendData( context, data, size ); + } +} + +_NS_EXPORT 
MTL::IOCompressionStatus MTL::IOFlushAndDestroyCompressionContext(void* context) +{ + if ( MTL::Private::MTLIOFlushAndDestroyCompressionContext ) + { + return MTL::Private::MTLIOFlushAndDestroyCompressionContext( context ); + } + return MTL::IOCompressionStatusError; +} + +#endif diff --git a/dist/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp b/dist/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp new file mode 100644 index 0000000..6944d56 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp @@ -0,0 +1,376 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIndirectCommandBuffer.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class IndirectCommandBufferDescriptor; +class IndirectComputeCommand; +class IndirectRenderCommand; + +_MTL_OPTIONS(NS::UInteger, IndirectCommandType) { + IndirectCommandTypeDraw = 1, + IndirectCommandTypeDrawIndexed = 1 << 1, + IndirectCommandTypeDrawPatches = 1 << 2, + IndirectCommandTypeDrawIndexedPatches = 1 << 3, + IndirectCommandTypeConcurrentDispatch = 1 << 5, + IndirectCommandTypeConcurrentDispatchThreads = 1 << 6, + IndirectCommandTypeDrawMeshThreadgroups = 1 << 7, + IndirectCommandTypeDrawMeshThreads = 1 << 8, +}; + +struct IndirectCommandBufferExecutionRange +{ + uint32_t location; + uint32_t length; +} _MTL_PACKED; + +class IndirectCommandBufferDescriptor : public NS::Copying +{ +public: + static IndirectCommandBufferDescriptor* alloc(); + + IndirectCommandType commandTypes() const; + + bool inheritBuffers() const; + + bool inheritCullMode() const; + + bool inheritDepthBias() const; + + bool inheritDepthClipMode() const; + + bool inheritDepthStencilState() const; + + bool inheritFrontFacingWinding() const; + + bool inheritPipelineState() const; + + bool inheritTriangleFillMode() const; + + IndirectCommandBufferDescriptor* init(); + + NS::UInteger maxFragmentBufferBindCount() const; + + NS::UInteger maxKernelBufferBindCount() const; + + NS::UInteger maxKernelThreadgroupMemoryBindCount() const; + + NS::UInteger maxMeshBufferBindCount() const; + + NS::UInteger maxObjectBufferBindCount() const; + + NS::UInteger maxObjectThreadgroupMemoryBindCount() const; + + NS::UInteger maxVertexBufferBindCount() const; + + void setCommandTypes(MTL::IndirectCommandType commandTypes); + 
+ void setInheritBuffers(bool inheritBuffers); + + void setInheritCullMode(bool inheritCullMode); + + void setInheritDepthBias(bool inheritDepthBias); + + void setInheritDepthClipMode(bool inheritDepthClipMode); + + void setInheritDepthStencilState(bool inheritDepthStencilState); + + void setInheritFrontFacingWinding(bool inheritFrontFacingWinding); + + void setInheritPipelineState(bool inheritPipelineState); + + void setInheritTriangleFillMode(bool inheritTriangleFillMode); + + void setMaxFragmentBufferBindCount(NS::UInteger maxFragmentBufferBindCount); + + void setMaxKernelBufferBindCount(NS::UInteger maxKernelBufferBindCount); + + void setMaxKernelThreadgroupMemoryBindCount(NS::UInteger maxKernelThreadgroupMemoryBindCount); + + void setMaxMeshBufferBindCount(NS::UInteger maxMeshBufferBindCount); + + void setMaxObjectBufferBindCount(NS::UInteger maxObjectBufferBindCount); + + void setMaxObjectThreadgroupMemoryBindCount(NS::UInteger maxObjectThreadgroupMemoryBindCount); + + void setMaxVertexBufferBindCount(NS::UInteger maxVertexBufferBindCount); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setSupportDynamicAttributeStride(bool supportDynamicAttributeStride); + + void setSupportRayTracing(bool supportRayTracing); + + bool supportColorAttachmentMapping() const; + + bool supportDynamicAttributeStride() const; + + bool supportRayTracing() const; +}; +class IndirectCommandBuffer : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + IndirectComputeCommand* indirectComputeCommand(NS::UInteger commandIndex); + + IndirectRenderCommand* indirectRenderCommand(NS::UInteger commandIndex); + + void reset(NS::Range range); + + NS::UInteger size() const; +}; + +} + +_MTL_INLINE MTL::IndirectCommandBufferDescriptor* MTL::IndirectCommandBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIndirectCommandBufferDescriptor)); +} + +_MTL_INLINE MTL::IndirectCommandType 
MTL::IndirectCommandBufferDescriptor::commandTypes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(commandTypes)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritBuffers)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritCullMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritCullMode)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthBias() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthBias)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthClipMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthClipMode)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritDepthStencilState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritDepthStencilState)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritFrontFacingWinding() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritFrontFacingWinding)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritPipelineState() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritPipelineState)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::inheritTriangleFillMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inheritTriangleFillMode)); +} + +_MTL_INLINE MTL::IndirectCommandBufferDescriptor* MTL::IndirectCommandBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxFragmentBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxFragmentBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxKernelBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxKernelBufferBindCount)); +} + 
+_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxKernelThreadgroupMemoryBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxKernelThreadgroupMemoryBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxMeshBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxMeshBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxObjectBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxObjectBufferBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxObjectThreadgroupMemoryBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxObjectThreadgroupMemoryBindCount)); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBufferDescriptor::maxVertexBufferBindCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexBufferBindCount)); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setCommandTypes(MTL::IndirectCommandType commandTypes) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCommandTypes_), commandTypes); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritBuffers(bool inheritBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritBuffers_), inheritBuffers); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritCullMode(bool inheritCullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritCullMode_), inheritCullMode); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritDepthBias(bool inheritDepthBias) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthBias_), inheritDepthBias); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritDepthClipMode(bool inheritDepthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthClipMode_), inheritDepthClipMode); +} + +_MTL_INLINE void 
MTL::IndirectCommandBufferDescriptor::setInheritDepthStencilState(bool inheritDepthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritDepthStencilState_), inheritDepthStencilState); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritFrontFacingWinding(bool inheritFrontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritFrontFacingWinding_), inheritFrontFacingWinding); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritPipelineState(bool inheritPipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritPipelineState_), inheritPipelineState); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setInheritTriangleFillMode(bool inheritTriangleFillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInheritTriangleFillMode_), inheritTriangleFillMode); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxFragmentBufferBindCount(NS::UInteger maxFragmentBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxFragmentBufferBindCount_), maxFragmentBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxKernelBufferBindCount(NS::UInteger maxKernelBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxKernelBufferBindCount_), maxKernelBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxKernelThreadgroupMemoryBindCount(NS::UInteger maxKernelThreadgroupMemoryBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxKernelThreadgroupMemoryBindCount_), maxKernelThreadgroupMemoryBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxMeshBufferBindCount(NS::UInteger maxMeshBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxMeshBufferBindCount_), maxMeshBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxObjectBufferBindCount(NS::UInteger maxObjectBufferBindCount) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxObjectBufferBindCount_), maxObjectBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxObjectThreadgroupMemoryBindCount(NS::UInteger maxObjectThreadgroupMemoryBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxObjectThreadgroupMemoryBindCount_), maxObjectThreadgroupMemoryBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setMaxVertexBufferBindCount(NS::UInteger maxVertexBufferBindCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexBufferBindCount_), maxVertexBufferBindCount); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportDynamicAttributeStride(bool supportDynamicAttributeStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportDynamicAttributeStride_), supportDynamicAttributeStride); +} + +_MTL_INLINE void MTL::IndirectCommandBufferDescriptor::setSupportRayTracing(bool supportRayTracing) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportRayTracing_), supportRayTracing); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportColorAttachmentMapping() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportDynamicAttributeStride() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportDynamicAttributeStride)); +} + +_MTL_INLINE bool MTL::IndirectCommandBufferDescriptor::supportRayTracing() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportRayTracing)); +} + +_MTL_INLINE MTL::ResourceID MTL::IndirectCommandBuffer::gpuResourceID() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE MTL::IndirectComputeCommand* MTL::IndirectCommandBuffer::indirectComputeCommand(NS::UInteger commandIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indirectComputeCommandAtIndex_), commandIndex); +} + +_MTL_INLINE MTL::IndirectRenderCommand* MTL::IndirectCommandBuffer::indirectRenderCommand(NS::UInteger commandIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indirectRenderCommandAtIndex_), commandIndex); +} + +_MTL_INLINE void MTL::IndirectCommandBuffer::reset(NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(resetWithRange_), range); +} + +_MTL_INLINE NS::UInteger MTL::IndirectCommandBuffer::size() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(size)); +} diff --git a/dist/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp new file mode 100644 index 0000000..9e1a940 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp @@ -0,0 +1,272 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIndirectCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class ComputePipelineState; +class RenderPipelineState; + +class IndirectRenderCommand : public NS::Referencing +{ +public: + void clearBarrier(); + + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger 
baseInstance); + + void reset(); + + void setBarrier(); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setFrontFacingWinding(MTL::Winding frontFacingWindning); + + void setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); +}; +class IndirectComputeCommand : public NS::Referencing +{ +public: + void clearBarrier(); + + void concurrentDispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup); + + void concurrentDispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup); + + void reset(); + + void setBarrier(); + + void setComputePipelineState(const MTL::ComputePipelineState* pipelineState); + + void setImageblockWidth(NS::UInteger width, NS::UInteger height); + + void setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setStageInRegion(MTL::Region region); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); +}; + +} +_MTL_INLINE void MTL::IndirectRenderCommand::clearBarrier() +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(clearBarrier)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, instanceCount, baseInstance, buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + 
+_MTL_INLINE void MTL::IndirectRenderCommand::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance, const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_tessellationFactorBuffer_tessellationFactorBufferOffset_tessellationFactorBufferInstanceStride_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, instanceCount, baseInstance, buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBarrier)); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthBias(float depthBias, 
float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setFrontFacingWinding(MTL::Winding frontFacingWindning) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWindning); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void 
MTL::IndirectRenderCommand::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectRenderCommand::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::clearBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(clearBarrier)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::concurrentDispatchThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(concurrentDispatchThreadgroups_threadsPerThreadgroup_), threadgroupsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::concurrentDispatchThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(concurrentDispatchThreads_threadsPerThreadgroup_), threadsPerGrid, threadsPerThreadgroup); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBarrier)); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setComputePipelineState(const MTL::ComputePipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setComputePipelineState_), pipelineState); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setImageblockWidth(NS::UInteger width, NS::UInteger height) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setImageblockWidth_height_), width, height); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setKernelBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setKernelBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setKernelBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setStageInRegion(MTL::Region region) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStageInRegion_), region); +} + +_MTL_INLINE void MTL::IndirectComputeCommand::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_atIndex_), length, index); +} diff --git a/dist/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp b/dist/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp new file mode 100644 index 0000000..436653b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLIntersectionFunctionTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class FunctionHandle; +class IntersectionFunctionTableDescriptor; +class VisibleFunctionTable; + +_MTL_OPTIONS(NS::UInteger, IntersectionFunctionSignature) { + IntersectionFunctionSignatureNone = 0, + IntersectionFunctionSignatureInstancing = 1, + IntersectionFunctionSignatureTriangleData = 1 << 1, + IntersectionFunctionSignatureWorldSpaceData = 1 << 2, + IntersectionFunctionSignatureInstanceMotion = 1 << 3, + IntersectionFunctionSignaturePrimitiveMotion = 1 << 4, + IntersectionFunctionSignatureExtendedLimits = 1 << 5, + IntersectionFunctionSignatureMaxLevels = 1 << 6, + IntersectionFunctionSignatureCurveData = 1 << 7, + IntersectionFunctionSignatureIntersectionFunctionBuffer = 1 << 8, + IntersectionFunctionSignatureUserData = 1 << 9, +}; + +struct IntersectionFunctionBufferArguments +{ + uint64_t intersectionFunctionBuffer; + uint64_t intersectionFunctionBufferSize; + uint64_t intersectionFunctionStride; +} _MTL_PACKED; + +class IntersectionFunctionTableDescriptor : public NS::Copying +{ +public: + static IntersectionFunctionTableDescriptor* alloc(); + + NS::UInteger functionCount() const; + + IntersectionFunctionTableDescriptor* init(); + + static IntersectionFunctionTableDescriptor* intersectionFunctionTableDescriptor(); + + void setFunctionCount(NS::UInteger functionCount); +}; +class IntersectionFunctionTable : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + void setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, 
NS::UInteger index); + void setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setFunction(const MTL::FunctionHandle* function, NS::UInteger index); + void setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range); + + void setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index); + void setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range); + + void setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index); + void setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range); + + void setVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range bufferRange); +}; + +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLIntersectionFunctionTableDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::IntersectionFunctionTableDescriptor::functionCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionCount)); +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::IntersectionFunctionTableDescriptor* MTL::IntersectionFunctionTableDescriptor::intersectionFunctionTableDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLIntersectionFunctionTableDescriptor), _MTL_PRIVATE_SEL(intersectionFunctionTableDescriptor)); +} + +_MTL_INLINE void MTL::IntersectionFunctionTableDescriptor::setFunctionCount(NS::UInteger functionCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionCount_), functionCount); +} + +_MTL_INLINE MTL::ResourceID 
MTL::IntersectionFunctionTable::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setFunction(const MTL::FunctionHandle* function, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunction_atIndex_), function, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_withRange_), functions, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueCurveIntersectionFunctionWithSignature_atIndex_), signature, index); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueCurveIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueCurveIntersectionFunctionWithSignature_withRange_), signature, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_atIndex_), signature, index); +} + +_MTL_INLINE void 
MTL::IntersectionFunctionTable::setOpaqueTriangleIntersectionFunction(MTL::IntersectionFunctionSignature signature, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOpaqueTriangleIntersectionFunctionWithSignature_withRange_), signature, range); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::IntersectionFunctionTable::setVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range bufferRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibleFunctionTables_withBufferRange_), functionTables, bufferRange); +} diff --git a/dist/include/metal_cpp/Metal/MTLLibrary.hpp b/dist/include/metal_cpp/Metal/MTLLibrary.hpp new file mode 100644 index 0000000..44aa3a7 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLLibrary.hpp @@ -0,0 +1,786 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLibrary.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDataType.hpp" +#include "MTLDefines.hpp" +#include "MTLFunctionDescriptor.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Argument; +class ArgumentEncoder; +class Attribute; +class CompileOptions; +class Device; +class Function; +class FunctionConstant; +class FunctionConstantValues; +class FunctionDescriptor; +class FunctionReflection; +class IntersectionFunctionDescriptor; +class VertexAttribute; +_MTL_ENUM(NS::UInteger, PatchType) { + PatchTypeNone = 0, + PatchTypeTriangle = 1, + PatchTypeQuad = 2, +}; + +_MTL_ENUM(NS::UInteger, FunctionType) { + FunctionTypeVertex = 1, + FunctionTypeFragment = 2, + FunctionTypeKernel = 3, + FunctionTypeVisible = 5, + FunctionTypeIntersection = 6, + FunctionTypeMesh = 7, + FunctionTypeObject = 8, +}; + +_MTL_ENUM(NS::UInteger, LanguageVersion) { + LanguageVersion1_0 = 65536, + LanguageVersion1_1 = 65537, + LanguageVersion1_2 = 65538, + LanguageVersion2_0 = 131072, + LanguageVersion2_1 = 131073, + LanguageVersion2_2 = 131074, + LanguageVersion2_3 = 131075, + LanguageVersion2_4 = 131076, + LanguageVersion3_0 = 196608, + LanguageVersion3_1 = 196609, + LanguageVersion3_2 = 196610, + LanguageVersion4_0 = 262144, +}; + +_MTL_ENUM(NS::Integer, LibraryType) { + LibraryTypeExecutable = 0, + LibraryTypeDynamic = 1, +}; + +_MTL_ENUM(NS::Integer, LibraryOptimizationLevel) { + LibraryOptimizationLevelDefault = 0, + LibraryOptimizationLevelSize = 1, +}; + +_MTL_ENUM(NS::Integer, CompileSymbolVisibility) { + CompileSymbolVisibilityDefault = 0, + CompileSymbolVisibilityHidden = 1, +}; + +_MTL_ENUM(NS::Integer, MathMode) { + MathModeSafe = 0, + MathModeRelaxed = 1, + MathModeFast = 2, +}; + +_MTL_ENUM(NS::Integer, MathFloatingPointFunctions) { + 
MathFloatingPointFunctionsFast = 0, + MathFloatingPointFunctionsPrecise = 1, +}; + +_MTL_ENUM(NS::UInteger, LibraryError) { + LibraryErrorUnsupported = 1, + LibraryErrorInternal = 2, + LibraryErrorCompileFailure = 3, + LibraryErrorCompileWarning = 4, + LibraryErrorFunctionNotFound = 5, + LibraryErrorFileNotFound = 6, +}; + +using AutoreleasedArgument = MTL::Argument*; +using FunctionCompletionHandlerFunction = std::function; + +class VertexAttribute : public NS::Referencing +{ +public: + [[deprecated("please use isActive instead")]] + bool active() const; + + static VertexAttribute* alloc(); + + NS::UInteger attributeIndex() const; + + DataType attributeType() const; + + VertexAttribute* init(); + + bool isActive() const; + + bool isPatchControlPointData() const; + + bool isPatchData() const; + + NS::String* name() const; + + [[deprecated("please use isPatchControlPointData instead")]] + bool patchControlPointData() const; + + [[deprecated("please use isPatchData instead")]] + bool patchData() const; +}; +class Attribute : public NS::Referencing +{ +public: + [[deprecated("please use isActive instead")]] + bool active() const; + + static Attribute* alloc(); + + NS::UInteger attributeIndex() const; + + DataType attributeType() const; + + Attribute* init(); + + bool isActive() const; + + bool isPatchControlPointData() const; + + bool isPatchData() const; + + NS::String* name() const; + + [[deprecated("please use isPatchControlPointData instead")]] + bool patchControlPointData() const; + + [[deprecated("please use isPatchData instead")]] + bool patchData() const; +}; +class FunctionConstant : public NS::Referencing +{ +public: + static FunctionConstant* alloc(); + + NS::UInteger index() const; + + FunctionConstant* init(); + + NS::String* name() const; + + bool required() const; + + DataType type() const; +}; +class Function : public NS::Referencing +{ +public: + Device* device() const; + + NS::Dictionary* functionConstantsDictionary() const; + + FunctionType 
functionType() const; + + NS::String* label() const; + + NS::String* name() const; + + ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex); + ArgumentEncoder* newArgumentEncoder(NS::UInteger bufferIndex, const MTL::AutoreleasedArgument* reflection); + + FunctionOptions options() const; + + NS::Integer patchControlPointCount() const; + + PatchType patchType() const; + + void setLabel(const NS::String* label); + + NS::Array* stageInputAttributes() const; + + NS::Array* vertexAttributes() const; +}; +class CompileOptions : public NS::Copying +{ +public: + static CompileOptions* alloc(); + + bool allowReferencingUndefinedSymbols() const; + + CompileSymbolVisibility compileSymbolVisibility() const; + + bool enableLogging() const; + + bool fastMathEnabled() const; + + CompileOptions* init(); + + NS::String* installName() const; + + LanguageVersion languageVersion() const; + + NS::Array* libraries() const; + + LibraryType libraryType() const; + + MathFloatingPointFunctions mathFloatingPointFunctions() const; + + MathMode mathMode() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + LibraryOptimizationLevel optimizationLevel() const; + + NS::Dictionary* preprocessorMacros() const; + + bool preserveInvariance() const; + + Size requiredThreadsPerThreadgroup() const; + + void setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols); + + void setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility); + + void setEnableLogging(bool enableLogging); + + void setFastMathEnabled(bool fastMathEnabled); + + void setInstallName(const NS::String* installName); + + void setLanguageVersion(MTL::LanguageVersion languageVersion); + + void setLibraries(const NS::Array* libraries); + + void setLibraryType(MTL::LibraryType libraryType); + + void setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions); + + void setMathMode(MTL::MathMode mathMode); + + void 
setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel); + + void setPreprocessorMacros(const NS::Dictionary* preprocessorMacros); + + void setPreserveInvariance(bool preserveInvariance); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); +}; +class FunctionReflection : public NS::Referencing +{ +public: + static FunctionReflection* alloc(); + + NS::Array* bindings() const; + + FunctionReflection* init(); +}; +class Library : public NS::Referencing +{ +public: + Device* device() const; + + NS::Array* functionNames() const; + + NS::String* installName() const; + + NS::String* label() const; + + Function* newFunction(const NS::String* functionName); + Function* newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, NS::Error** error); + void newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, void (^completionHandler)(MTL::Function*, NS::Error*)); + void newFunction(const MTL::FunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)); + Function* newFunction(const MTL::FunctionDescriptor* descriptor, NS::Error** error); + void newFunction(const NS::String* pFunctionName, const MTL::FunctionConstantValues* pConstantValues, const MTL::FunctionCompletionHandlerFunction& completionHandler); + void newFunction(const MTL::FunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler); + + void newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)); + Function* newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, NS::Error** error); + void newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler); + + FunctionReflection* 
reflectionForFunction(const NS::String* functionName); + + void setLabel(const NS::String* label); + + LibraryType type() const; +}; + +} +_MTL_INLINE bool MTL::VertexAttribute::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::VertexAttribute* MTL::VertexAttribute::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttribute)); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttribute::attributeIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeIndex)); +} + +_MTL_INLINE MTL::DataType MTL::VertexAttribute::attributeType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeType)); +} + +_MTL_INLINE MTL::VertexAttribute* MTL::VertexAttribute::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::VertexAttribute::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::VertexAttribute::isPatchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::VertexAttribute::isPatchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE NS::String* MTL::VertexAttribute::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::VertexAttribute::patchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::VertexAttribute::patchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE bool MTL::Attribute::active() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE MTL::Attribute* MTL::Attribute::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttribute)); +} + +_MTL_INLINE NS::UInteger MTL::Attribute::attributeIndex() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(attributeIndex)); +} + +_MTL_INLINE MTL::DataType MTL::Attribute::attributeType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributeType)); +} + +_MTL_INLINE MTL::Attribute* MTL::Attribute::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::Attribute::isActive() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isActive)); +} + +_MTL_INLINE bool MTL::Attribute::isPatchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::Attribute::isPatchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE NS::String* MTL::Attribute::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::Attribute::patchControlPointData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchControlPointData)); +} + +_MTL_INLINE bool MTL::Attribute::patchData() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isPatchData)); +} + +_MTL_INLINE MTL::FunctionConstant* MTL::FunctionConstant::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionConstant)); +} + +_MTL_INLINE NS::UInteger MTL::FunctionConstant::index() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(index)); +} + +_MTL_INLINE MTL::FunctionConstant* MTL::FunctionConstant::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::FunctionConstant::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE bool MTL::FunctionConstant::required() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(required)); +} + +_MTL_INLINE MTL::DataType MTL::FunctionConstant::type() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} + +_MTL_INLINE MTL::Device* MTL::Function::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::Dictionary* 
MTL::Function::functionConstantsDictionary() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionConstantsDictionary)); +} + +_MTL_INLINE MTL::FunctionType MTL::Function::functionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionType)); +} + +_MTL_INLINE NS::String* MTL::Function::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::String* MTL::Function::name() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(name)); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Function::newArgumentEncoder(NS::UInteger bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferIndex_), bufferIndex); +} + +_MTL_INLINE MTL::ArgumentEncoder* MTL::Function::newArgumentEncoder(NS::UInteger bufferIndex, const MTL::AutoreleasedArgument* reflection) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newArgumentEncoderWithBufferIndex_reflection_), bufferIndex, reflection); +} + +_MTL_INLINE MTL::FunctionOptions MTL::Function::options() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(options)); +} + +_MTL_INLINE NS::Integer MTL::Function::patchControlPointCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(patchControlPointCount)); +} + +_MTL_INLINE MTL::PatchType MTL::Function::patchType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(patchType)); +} + +_MTL_INLINE void MTL::Function::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE NS::Array* MTL::Function::stageInputAttributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stageInputAttributes)); +} + +_MTL_INLINE NS::Array* MTL::Function::vertexAttributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAttributes)); +} + +_MTL_INLINE MTL::CompileOptions* MTL::CompileOptions::alloc() +{ + return 
NS::Object::alloc(_MTL_PRIVATE_CLS(MTLCompileOptions)); +} + +_MTL_INLINE bool MTL::CompileOptions::allowReferencingUndefinedSymbols() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowReferencingUndefinedSymbols)); +} + +_MTL_INLINE MTL::CompileSymbolVisibility MTL::CompileOptions::compileSymbolVisibility() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compileSymbolVisibility)); +} + +_MTL_INLINE bool MTL::CompileOptions::enableLogging() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(enableLogging)); +} + +_MTL_INLINE bool MTL::CompileOptions::fastMathEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fastMathEnabled)); +} + +_MTL_INLINE MTL::CompileOptions* MTL::CompileOptions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::CompileOptions::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE MTL::LanguageVersion MTL::CompileOptions::languageVersion() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(languageVersion)); +} + +_MTL_INLINE NS::Array* MTL::CompileOptions::libraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(libraries)); +} + +_MTL_INLINE MTL::LibraryType MTL::CompileOptions::libraryType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(libraryType)); +} + +_MTL_INLINE MTL::MathFloatingPointFunctions MTL::CompileOptions::mathFloatingPointFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathFloatingPointFunctions)); +} + +_MTL_INLINE MTL::MathMode MTL::CompileOptions::mathMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mathMode)); +} + +_MTL_INLINE NS::UInteger MTL::CompileOptions::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE MTL::LibraryOptimizationLevel MTL::CompileOptions::optimizationLevel() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(optimizationLevel)); +} + +_MTL_INLINE NS::Dictionary* MTL::CompileOptions::preprocessorMacros() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preprocessorMacros)); +} + +_MTL_INLINE bool MTL::CompileOptions::preserveInvariance() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preserveInvariance)); +} + +_MTL_INLINE MTL::Size MTL::CompileOptions::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::CompileOptions::setAllowReferencingUndefinedSymbols(bool allowReferencingUndefinedSymbols) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowReferencingUndefinedSymbols_), allowReferencingUndefinedSymbols); +} + +_MTL_INLINE void MTL::CompileOptions::setCompileSymbolVisibility(MTL::CompileSymbolVisibility compileSymbolVisibility) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompileSymbolVisibility_), compileSymbolVisibility); +} + +_MTL_INLINE void MTL::CompileOptions::setEnableLogging(bool enableLogging) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEnableLogging_), enableLogging); +} + +_MTL_INLINE void MTL::CompileOptions::setFastMathEnabled(bool fastMathEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFastMathEnabled_), fastMathEnabled); +} + +_MTL_INLINE void MTL::CompileOptions::setInstallName(const NS::String* installName) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInstallName_), installName); +} + +_MTL_INLINE void MTL::CompileOptions::setLanguageVersion(MTL::LanguageVersion languageVersion) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLanguageVersion_), languageVersion); +} + +_MTL_INLINE void MTL::CompileOptions::setLibraries(const NS::Array* libraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLibraries_), libraries); +} + +_MTL_INLINE void MTL::CompileOptions::setLibraryType(MTL::LibraryType libraryType) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setLibraryType_), libraryType); +} + +_MTL_INLINE void MTL::CompileOptions::setMathFloatingPointFunctions(MTL::MathFloatingPointFunctions mathFloatingPointFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathFloatingPointFunctions_), mathFloatingPointFunctions); +} + +_MTL_INLINE void MTL::CompileOptions::setMathMode(MTL::MathMode mathMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMathMode_), mathMode); +} + +_MTL_INLINE void MTL::CompileOptions::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::CompileOptions::setOptimizationLevel(MTL::LibraryOptimizationLevel optimizationLevel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOptimizationLevel_), optimizationLevel); +} + +_MTL_INLINE void MTL::CompileOptions::setPreprocessorMacros(const NS::Dictionary* preprocessorMacros) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreprocessorMacros_), preprocessorMacros); +} + +_MTL_INLINE void MTL::CompileOptions::setPreserveInvariance(bool preserveInvariance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreserveInvariance_), preserveInvariance); +} + +_MTL_INLINE void MTL::CompileOptions::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::FunctionReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLFunctionReflection)); +} + +_MTL_INLINE NS::Array* MTL::FunctionReflection::bindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bindings)); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::FunctionReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Device* MTL::Library::device() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::Array* MTL::Library::functionNames() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionNames)); +} + +_MTL_INLINE NS::String* MTL::Library::installName() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(installName)); +} + +_MTL_INLINE NS::String* MTL::Library::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const NS::String* functionName) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_), functionName); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_constantValues_error_), name, constantValues, error); +} + +_MTL_INLINE void MTL::Library::newFunction(const NS::String* name, const MTL::FunctionConstantValues* constantValues, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithName_constantValues_completionHandler_), name, constantValues, completionHandler); +} + +_MTL_INLINE void MTL::Library::newFunction(const MTL::FunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL::Function* MTL::Library::newFunction(const MTL::FunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Library::newFunction(const NS::String* pFunctionName, const MTL::FunctionConstantValues* pConstantValues, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction 
blockCompletionHandler = completionHandler; + newFunction(pFunctionName, pConstantValues, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE void MTL::Library::newFunction(const MTL::FunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newFunction(pDescriptor, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE void MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, void (^completionHandler)(MTL::Function*, NS::Error*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionWithDescriptor_completionHandler_), descriptor, completionHandler); +} + +_MTL_INLINE MTL::Function* MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* descriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionWithDescriptor_error_), descriptor, error); +} + +_MTL_INLINE void MTL::Library::newIntersectionFunction(const MTL::IntersectionFunctionDescriptor* pDescriptor, const MTL::FunctionCompletionHandlerFunction& completionHandler) +{ + __block MTL::FunctionCompletionHandlerFunction blockCompletionHandler = completionHandler; + newIntersectionFunction(pDescriptor, ^(MTL::Function* pFunction, NS::Error* pError) { blockCompletionHandler(pFunction, pError); }); +} + +_MTL_INLINE MTL::FunctionReflection* MTL::Library::reflectionForFunction(const NS::String* functionName) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflectionForFunctionWithName_), functionName); +} + +_MTL_INLINE void MTL::Library::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE MTL::LibraryType MTL::Library::type() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(type)); +} diff --git a/dist/include/metal_cpp/Metal/MTLLinkedFunctions.hpp b/dist/include/metal_cpp/Metal/MTLLinkedFunctions.hpp new file mode 100644 index 0000000..4b1bd95 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLLinkedFunctions.hpp @@ -0,0 +1,110 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLinkedFunctions.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ + +class LinkedFunctions : public NS::Copying +{ +public: + static LinkedFunctions* alloc(); + + NS::Array* binaryFunctions() const; + NS::Array* functions() const; + + NS::Dictionary* groups() const; + + LinkedFunctions* init(); + + static LinkedFunctions* linkedFunctions(); + + NS::Array* privateFunctions() const; + + void setBinaryFunctions(const NS::Array* binaryFunctions); + + void setFunctions(const NS::Array* functions); + + void setGroups(const NS::Dictionary* groups); + + void setPrivateFunctions(const NS::Array* privateFunctions); +}; + +} +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::binaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::functions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functions)); +} + +_MTL_INLINE NS::Dictionary* MTL::LinkedFunctions::groups() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(groups)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::LinkedFunctions::linkedFunctions() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLLinkedFunctions), _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::LinkedFunctions::privateFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(privateFunctions)); +} + +_MTL_INLINE void MTL::LinkedFunctions::setBinaryFunctions(const NS::Array* binaryFunctions) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryFunctions_), binaryFunctions); +} + +_MTL_INLINE void MTL::LinkedFunctions::setFunctions(const NS::Array* functions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_), functions); +} + +_MTL_INLINE void MTL::LinkedFunctions::setGroups(const NS::Dictionary* groups) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setGroups_), groups); +} + +_MTL_INLINE void MTL::LinkedFunctions::setPrivateFunctions(const NS::Array* privateFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPrivateFunctions_), privateFunctions); +} diff --git a/dist/include/metal_cpp/Metal/MTLLogState.hpp b/dist/include/metal_cpp/Metal/MTLLogState.hpp new file mode 100644 index 0000000..b802adf --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLLogState.hpp @@ -0,0 +1,111 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLLogState.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class LogStateDescriptor; +_MTL_ENUM(NS::Integer, LogLevel) { + LogLevelUndefined = 0, + LogLevelDebug = 1, + LogLevelInfo = 2, + LogLevelNotice = 3, + LogLevelError = 4, + LogLevelFault = 5, +}; + +_MTL_ENUM(NS::UInteger, LogStateError) { + LogStateErrorInvalidSize = 1, + LogStateErrorInvalid = 2, +}; + +using LogHandlerFunction = std::function; + +_MTL_CONST(NS::ErrorDomain, LogStateErrorDomain); +class LogState : public NS::Referencing +{ +public: + void addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)); + void addLogHandler(const MTL::LogHandlerFunction& handler); +}; +class LogStateDescriptor : public NS::Copying +{ +public: + static LogStateDescriptor* alloc(); + + NS::Integer bufferSize() const; + + LogStateDescriptor* init(); + + LogLevel level() const; + + void setBufferSize(NS::Integer bufferSize); + + void setLevel(MTL::LogLevel level); +}; + +} +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, LogStateErrorDomain); +_MTL_INLINE void MTL::LogState::addLogHandler(void (^block)(NS::String*, NS::String*, MTL::LogLevel, NS::String*)) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addLogHandler_), block); +} + +_MTL_INLINE void MTL::LogState::addLogHandler(const MTL::LogHandlerFunction& handler) +{ + __block LogHandlerFunction function = handler; + addLogHandler(^void(NS::String* subsystem, NS::String* category, MTL::LogLevel logLevel, NS::String* message) { function(subsystem, category, logLevel, message); }); +} + +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLogStateDescriptor)); +} + +_MTL_INLINE NS::Integer 
MTL::LogStateDescriptor::bufferSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferSize)); +} + +_MTL_INLINE MTL::LogStateDescriptor* MTL::LogStateDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::LogLevel MTL::LogStateDescriptor::level() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(level)); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setBufferSize(NS::Integer bufferSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferSize_), bufferSize); +} + +_MTL_INLINE void MTL::LogStateDescriptor::setLevel(MTL::LogLevel level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevel_), level); +} diff --git a/dist/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp new file mode 100644 index 0000000..8c34512 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp @@ -0,0 +1,83 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLParallelRenderCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" + +namespace MTL +{ +class RenderCommandEncoder; + +class ParallelRenderCommandEncoder : public NS::Referencing +{ +public: + RenderCommandEncoder* renderCommandEncoder(); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + void setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex); + + void setDepthStoreAction(MTL::StoreAction storeAction); + void setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setStencilStoreAction(MTL::StoreAction storeAction); + void setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions); +}; + +} +_MTL_INLINE MTL::RenderCommandEncoder* MTL::ParallelRenderCommandEncoder::renderCommandEncoder() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderCommandEncoder)); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreActionOptions_atIndex_), storeActionOptions, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void 
MTL::ParallelRenderCommandEncoder::setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::ParallelRenderCommandEncoder::setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreActionOptions_), storeActionOptions); +} diff --git a/dist/include/metal_cpp/Metal/MTLPipeline.hpp b/dist/include/metal_cpp/Metal/MTLPipeline.hpp new file mode 100644 index 0000000..930bb7e --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLPipeline.hpp @@ -0,0 +1,104 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class PipelineBufferDescriptor; +class PipelineBufferDescriptorArray; +_MTL_ENUM(NS::UInteger, Mutability) { + MutabilityDefault = 0, + MutabilityMutable = 1, + MutabilityImmutable = 2, +}; + +_MTL_ENUM(NS::Integer, ShaderValidation) { + ShaderValidationDefault = 0, + ShaderValidationEnabled = 1, + ShaderValidationDisabled = 2, +}; + +class PipelineBufferDescriptor : public NS::Copying +{ +public: + static PipelineBufferDescriptor* alloc(); + + PipelineBufferDescriptor* init(); + + Mutability mutability() const; + void setMutability(MTL::Mutability mutability); +}; +class PipelineBufferDescriptorArray : public NS::Referencing +{ +public: + static PipelineBufferDescriptorArray* alloc(); + + PipelineBufferDescriptorArray* init(); + + PipelineBufferDescriptor* object(NS::UInteger bufferIndex); + void setObject(const MTL::PipelineBufferDescriptor* buffer, NS::UInteger bufferIndex); +}; + +} +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPipelineBufferDescriptor)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::Mutability MTL::PipelineBufferDescriptor::mutability() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mutability)); +} + +_MTL_INLINE void MTL::PipelineBufferDescriptor::setMutability(MTL::Mutability mutability) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMutability_), mutability); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::PipelineBufferDescriptorArray::alloc() +{ + return 
NS::Object::alloc(_MTL_PRIVATE_CLS(MTLPipelineBufferDescriptorArray)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::PipelineBufferDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PipelineBufferDescriptor* MTL::PipelineBufferDescriptorArray::object(NS::UInteger bufferIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), bufferIndex); +} + +_MTL_INLINE void MTL::PipelineBufferDescriptorArray::setObject(const MTL::PipelineBufferDescriptor* buffer, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), buffer, bufferIndex); +} diff --git a/dist/include/metal_cpp/Metal/MTLPixelFormat.hpp b/dist/include/metal_cpp/Metal/MTLPixelFormat.hpp new file mode 100644 index 0000000..6d5d886 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLPixelFormat.hpp @@ -0,0 +1,173 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPixelFormat.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, PixelFormat) { + PixelFormatInvalid = 0, + PixelFormatA8Unorm = 1, + PixelFormatR8Unorm = 10, + PixelFormatR8Unorm_sRGB = 11, + PixelFormatR8Snorm = 12, + PixelFormatR8Uint = 13, + PixelFormatR8Sint = 14, + PixelFormatR16Unorm = 20, + PixelFormatR16Snorm = 22, + PixelFormatR16Uint = 23, + PixelFormatR16Sint = 24, + PixelFormatR16Float = 25, + PixelFormatRG8Unorm = 30, + PixelFormatRG8Unorm_sRGB = 31, + PixelFormatRG8Snorm = 32, + PixelFormatRG8Uint = 33, + PixelFormatRG8Sint = 34, + PixelFormatB5G6R5Unorm = 40, + PixelFormatA1BGR5Unorm = 41, + PixelFormatABGR4Unorm = 42, + PixelFormatBGR5A1Unorm = 43, + PixelFormatR32Uint = 53, + PixelFormatR32Sint = 54, + PixelFormatR32Float = 55, + PixelFormatRG16Unorm = 60, + PixelFormatRG16Snorm = 62, + PixelFormatRG16Uint = 63, + PixelFormatRG16Sint = 64, + PixelFormatRG16Float = 65, + PixelFormatRGBA8Unorm = 70, + PixelFormatRGBA8Unorm_sRGB = 71, + PixelFormatRGBA8Snorm = 72, + PixelFormatRGBA8Uint = 73, + PixelFormatRGBA8Sint = 74, + PixelFormatBGRA8Unorm = 80, + PixelFormatBGRA8Unorm_sRGB = 81, + PixelFormatRGB10A2Unorm = 90, + PixelFormatRGB10A2Uint = 91, + PixelFormatRG11B10Float = 92, + PixelFormatRGB9E5Float = 93, + PixelFormatBGR10A2Unorm = 94, + PixelFormatBGR10_XR = 554, + PixelFormatBGR10_XR_sRGB = 555, + PixelFormatRG32Uint = 103, + PixelFormatRG32Sint = 104, + PixelFormatRG32Float = 105, + PixelFormatRGBA16Unorm = 110, + PixelFormatRGBA16Snorm = 112, + PixelFormatRGBA16Uint = 113, + PixelFormatRGBA16Sint = 114, + PixelFormatRGBA16Float = 115, + PixelFormatBGRA10_XR = 552, + PixelFormatBGRA10_XR_sRGB = 553, + PixelFormatRGBA32Uint = 123, + PixelFormatRGBA32Sint = 
124, + PixelFormatRGBA32Float = 125, + PixelFormatBC1_RGBA = 130, + PixelFormatBC1_RGBA_sRGB = 131, + PixelFormatBC2_RGBA = 132, + PixelFormatBC2_RGBA_sRGB = 133, + PixelFormatBC3_RGBA = 134, + PixelFormatBC3_RGBA_sRGB = 135, + PixelFormatBC4_RUnorm = 140, + PixelFormatBC4_RSnorm = 141, + PixelFormatBC5_RGUnorm = 142, + PixelFormatBC5_RGSnorm = 143, + PixelFormatBC6H_RGBFloat = 150, + PixelFormatBC6H_RGBUfloat = 151, + PixelFormatBC7_RGBAUnorm = 152, + PixelFormatBC7_RGBAUnorm_sRGB = 153, + PixelFormatPVRTC_RGB_2BPP = 160, + PixelFormatPVRTC_RGB_2BPP_sRGB = 161, + PixelFormatPVRTC_RGB_4BPP = 162, + PixelFormatPVRTC_RGB_4BPP_sRGB = 163, + PixelFormatPVRTC_RGBA_2BPP = 164, + PixelFormatPVRTC_RGBA_2BPP_sRGB = 165, + PixelFormatPVRTC_RGBA_4BPP = 166, + PixelFormatPVRTC_RGBA_4BPP_sRGB = 167, + PixelFormatEAC_R11Unorm = 170, + PixelFormatEAC_R11Snorm = 172, + PixelFormatEAC_RG11Unorm = 174, + PixelFormatEAC_RG11Snorm = 176, + PixelFormatEAC_RGBA8 = 178, + PixelFormatEAC_RGBA8_sRGB = 179, + PixelFormatETC2_RGB8 = 180, + PixelFormatETC2_RGB8_sRGB = 181, + PixelFormatETC2_RGB8A1 = 182, + PixelFormatETC2_RGB8A1_sRGB = 183, + PixelFormatASTC_4x4_sRGB = 186, + PixelFormatASTC_5x4_sRGB = 187, + PixelFormatASTC_5x5_sRGB = 188, + PixelFormatASTC_6x5_sRGB = 189, + PixelFormatASTC_6x6_sRGB = 190, + PixelFormatASTC_8x5_sRGB = 192, + PixelFormatASTC_8x6_sRGB = 193, + PixelFormatASTC_8x8_sRGB = 194, + PixelFormatASTC_10x5_sRGB = 195, + PixelFormatASTC_10x6_sRGB = 196, + PixelFormatASTC_10x8_sRGB = 197, + PixelFormatASTC_10x10_sRGB = 198, + PixelFormatASTC_12x10_sRGB = 199, + PixelFormatASTC_12x12_sRGB = 200, + PixelFormatASTC_4x4_LDR = 204, + PixelFormatASTC_5x4_LDR = 205, + PixelFormatASTC_5x5_LDR = 206, + PixelFormatASTC_6x5_LDR = 207, + PixelFormatASTC_6x6_LDR = 208, + PixelFormatASTC_8x5_LDR = 210, + PixelFormatASTC_8x6_LDR = 211, + PixelFormatASTC_8x8_LDR = 212, + PixelFormatASTC_10x5_LDR = 213, + PixelFormatASTC_10x6_LDR = 214, + PixelFormatASTC_10x8_LDR = 215, + 
PixelFormatASTC_10x10_LDR = 216, + PixelFormatASTC_12x10_LDR = 217, + PixelFormatASTC_12x12_LDR = 218, + PixelFormatASTC_4x4_HDR = 222, + PixelFormatASTC_5x4_HDR = 223, + PixelFormatASTC_5x5_HDR = 224, + PixelFormatASTC_6x5_HDR = 225, + PixelFormatASTC_6x6_HDR = 226, + PixelFormatASTC_8x5_HDR = 228, + PixelFormatASTC_8x6_HDR = 229, + PixelFormatASTC_8x8_HDR = 230, + PixelFormatASTC_10x5_HDR = 231, + PixelFormatASTC_10x6_HDR = 232, + PixelFormatASTC_10x8_HDR = 233, + PixelFormatASTC_10x10_HDR = 234, + PixelFormatASTC_12x10_HDR = 235, + PixelFormatASTC_12x12_HDR = 236, + PixelFormatGBGR422 = 240, + PixelFormatBGRG422 = 241, + PixelFormatDepth16Unorm = 250, + PixelFormatDepth32Float = 252, + PixelFormatStencil8 = 253, + PixelFormatDepth24Unorm_Stencil8 = 255, + PixelFormatDepth32Float_Stencil8 = 260, + PixelFormatX32_Stencil8 = 261, + PixelFormatX24_Stencil8 = 262, + PixelFormatUnspecialized = 263, +}; + +} diff --git a/dist/include/metal_cpp/Metal/MTLPrivate.hpp b/dist/include/metal_cpp/Metal/MTLPrivate.hpp new file mode 100644 index 0000000..41bcaa5 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLPrivate.hpp @@ -0,0 +1,156 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLDefines.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTL_PRIVATE_CLS(symbol) (MTL::Private::Class::s_k##symbol) +#define _MTL_PRIVATE_SEL(accessor) (MTL::Private::Selector::s_k##accessor) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined(MTL_PRIVATE_IMPLEMENTATION) + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _MTL_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _MTL_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _MTL_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _MTL_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _MTL_PRIVATE_VISIBILITY = _MTL_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _MTL_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _MTL_PRIVATE_VISIBILITY = _MTL_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _MTL_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _MTL_PRIVATE_VISIBILITY = sel_registerName(symbol) + +#include +#define MTL_DEF_FUNC( name, signature ) \ + 
using Fn##name = signature; \ + Fn##name name = reinterpret_cast< Fn##name >( dlsym( RTLD_DEFAULT, #name ) ) + +namespace MTL::Private +{ + template + inline _Type const LoadSymbol(const char* pSymbol) + { + const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol)); + + return pAddress ? *pAddress : nullptr; + } +} // MTL::Private + +#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0) + +#define _MTL_PRIVATE_DEF_STR(type, symbol) \ + _MTL_EXTERN type const MTL##symbol _MTL_PRIVATE_IMPORT; \ + type const MTL::symbol = (nullptr != &MTL##symbol) ? MTL##symbol : nullptr + +#define _MTL_PRIVATE_DEF_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol _MTL_PRIVATE_IMPORT; \ + type const MTL::symbol = (nullptr != &MTL##symbol) ? MTL##symbol : nullptr + +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#else + +#define _MTL_PRIVATE_DEF_STR(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#define _MTL_PRIVATE_DEF_CONST(type, symbol) \ + _MTL_EXTERN type const MTL##symbol; \ + type const MTL::symbol = MTL::Private::LoadSymbol("MTL" #symbol) + +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) _MTL_PRIVATE_DEF_CONST(type, symbol) + +#endif + +#else + +#define _MTL_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol +#define _MTL_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol +#define _MTL_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor +#define _MTL_PRIVATE_DEF_STR(type, symbol) extern type const MTL::symbol +#define _MTL_PRIVATE_DEF_CONST(type, symbol) extern type const MTL::symbol +#define _MTL_PRIVATE_DEF_WEAK_CONST(type, symbol) extern type const MTL::symbol + +#endif // MTL_PRIVATE_IMPLEMENTATION + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Class + { + + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Protocol + { + + } // Protocol +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL +{ +namespace Private +{ + namespace Selector + { + + _MTL_PRIVATE_DEF_SEL(beginScope, + "beginScope"); + _MTL_PRIVATE_DEF_SEL(endScope, + "endScope"); + } // Class +} // Private +} // MTL + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/Metal/MTLRasterizationRate.hpp b/dist/include/metal_cpp/Metal/MTLRasterizationRate.hpp new file mode 100644 index 0000000..b2804fa --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLRasterizationRate.hpp @@ -0,0 +1,337 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRasterizationRate.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDevice.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class Device; +class RasterizationRateLayerArray; +class RasterizationRateLayerDescriptor; +class RasterizationRateMapDescriptor; +class RasterizationRateSampleArray; + +class RasterizationRateSampleArray : public NS::Referencing +{ +public: + static RasterizationRateSampleArray* alloc(); + + RasterizationRateSampleArray* init(); + + NS::Number* object(NS::UInteger index); + void setObject(const NS::Number* value, NS::UInteger index); +}; +class RasterizationRateLayerDescriptor : public NS::Copying +{ +public: + static RasterizationRateLayerDescriptor* alloc(); + + RasterizationRateSampleArray* horizontal() const; + float* horizontalSampleStorage() const; + + RasterizationRateLayerDescriptor* init(); + RasterizationRateLayerDescriptor* init(MTL::Size sampleCount); + RasterizationRateLayerDescriptor* init(MTL::Size sampleCount, const float* horizontal, const float* vertical); + + Size maxSampleCount() const; + Size sampleCount() const; + void setSampleCount(MTL::Size sampleCount); + + RasterizationRateSampleArray* vertical() const; + float* verticalSampleStorage() const; +}; +class RasterizationRateLayerArray : public NS::Referencing +{ +public: + static 
RasterizationRateLayerArray* alloc(); + + RasterizationRateLayerArray* init(); + + RasterizationRateLayerDescriptor* object(NS::UInteger layerIndex); + void setObject(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex); +}; +class RasterizationRateMapDescriptor : public NS::Copying +{ +public: + static RasterizationRateMapDescriptor* alloc(); + + RasterizationRateMapDescriptor* init(); + + NS::String* label() const; + + RasterizationRateLayerDescriptor* layer(NS::UInteger layerIndex); + NS::UInteger layerCount() const; + + RasterizationRateLayerArray* layers() const; + + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize); + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize, const MTL::RasterizationRateLayerDescriptor* layer); + static RasterizationRateMapDescriptor* rasterizationRateMapDescriptor(MTL::Size screenSize, NS::UInteger layerCount, const MTL::RasterizationRateLayerDescriptor* const* layers); + + Size screenSize() const; + + void setLabel(const NS::String* label); + + void setLayer(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex); + + void setScreenSize(MTL::Size screenSize); +}; +class RasterizationRateMap : public NS::Referencing +{ +public: + void copyParameterDataToBuffer(const MTL::Buffer* buffer, NS::UInteger offset); + + Device* device() const; + + NS::String* label() const; + + NS::UInteger layerCount() const; + + Coordinate2D mapPhysicalToScreenCoordinates(MTL::Coordinate2D physicalCoordinates, NS::UInteger layerIndex); + + Coordinate2D mapScreenToPhysicalCoordinates(MTL::Coordinate2D screenCoordinates, NS::UInteger layerIndex); + + SizeAndAlign parameterBufferSizeAndAlign() const; + + Size physicalGranularity() const; + + Size physicalSize(NS::UInteger layerIndex); + + Size screenSize() const; +}; + +} +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateSampleArray::alloc() +{ + return 
NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateSampleArray)); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateSampleArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Number* MTL::RasterizationRateSampleArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::RasterizationRateSampleArray::setObject(const NS::Number* value, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), value, index); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateLayerDescriptor)); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateLayerDescriptor::horizontal() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(horizontal)); +} + +_MTL_INLINE float* MTL::RasterizationRateLayerDescriptor::horizontalSampleStorage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(horizontalSampleStorage)); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init(MTL::Size sampleCount) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithSampleCount_), sampleCount); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerDescriptor::init(MTL::Size sampleCount, const float* horizontal, const float* vertical) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithSampleCount_horizontal_vertical_), sampleCount, horizontal, vertical); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateLayerDescriptor::maxSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxSampleCount)); +} + +_MTL_INLINE MTL::Size 
MTL::RasterizationRateLayerDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::RasterizationRateLayerDescriptor::setSampleCount(MTL::Size sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE MTL::RasterizationRateSampleArray* MTL::RasterizationRateLayerDescriptor::vertical() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertical)); +} + +_MTL_INLINE float* MTL::RasterizationRateLayerDescriptor::verticalSampleStorage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(verticalSampleStorage)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateLayerArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateLayerArray)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateLayerArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateLayerArray::object(NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), layerIndex); +} + +_MTL_INLINE void MTL::RasterizationRateLayerArray::setObject(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), layer, layerIndex); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor)); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::RasterizationRateMapDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::RasterizationRateLayerDescriptor* MTL::RasterizationRateMapDescriptor::layer(NS::UInteger 
layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerAtIndex_), layerIndex); +} + +_MTL_INLINE NS::UInteger MTL::RasterizationRateMapDescriptor::layerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerCount)); +} + +_MTL_INLINE MTL::RasterizationRateLayerArray* MTL::RasterizationRateMapDescriptor::layers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layers)); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_), screenSize); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize, const MTL::RasterizationRateLayerDescriptor* layer) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_layer_), screenSize, layer); +} + +_MTL_INLINE MTL::RasterizationRateMapDescriptor* MTL::RasterizationRateMapDescriptor::rasterizationRateMapDescriptor(MTL::Size screenSize, NS::UInteger layerCount, const MTL::RasterizationRateLayerDescriptor* const* layers) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRasterizationRateMapDescriptor), _MTL_PRIVATE_SEL(rasterizationRateMapDescriptorWithScreenSize_layerCount_layers_), screenSize, layerCount, layers); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMapDescriptor::screenSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(screenSize)); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setLayer(const MTL::RasterizationRateLayerDescriptor* layer, NS::UInteger layerIndex) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setLayer_atIndex_), layer, layerIndex); +} + +_MTL_INLINE void MTL::RasterizationRateMapDescriptor::setScreenSize(MTL::Size screenSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScreenSize_), screenSize); +} + +_MTL_INLINE void MTL::RasterizationRateMap::copyParameterDataToBuffer(const MTL::Buffer* buffer, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(copyParameterDataToBuffer_offset_), buffer, offset); +} + +_MTL_INLINE MTL::Device* MTL::RasterizationRateMap::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::RasterizationRateMap::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RasterizationRateMap::layerCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layerCount)); +} + +_MTL_INLINE MTL::Coordinate2D MTL::RasterizationRateMap::mapPhysicalToScreenCoordinates(MTL::Coordinate2D physicalCoordinates, NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mapPhysicalToScreenCoordinates_forLayer_), physicalCoordinates, layerIndex); +} + +_MTL_INLINE MTL::Coordinate2D MTL::RasterizationRateMap::mapScreenToPhysicalCoordinates(MTL::Coordinate2D screenCoordinates, NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mapScreenToPhysicalCoordinates_forLayer_), screenCoordinates, layerIndex); +} + +_MTL_INLINE MTL::SizeAndAlign MTL::RasterizationRateMap::parameterBufferSizeAndAlign() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parameterBufferSizeAndAlign)); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::physicalGranularity() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(physicalGranularity)); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::physicalSize(NS::UInteger layerIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(physicalSizeForLayer_), 
layerIndex); +} + +_MTL_INLINE MTL::Size MTL::RasterizationRateMap::screenSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(screenSize)); +} diff --git a/dist/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp new file mode 100644 index 0000000..b2667b7 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp @@ -0,0 +1,1019 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStatePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderPass.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class AccelerationStructure; +class Buffer; +class CounterSampleBuffer; +class DepthStencilState; +class Fence; +class Heap; +class IndirectCommandBuffer; +class IntersectionFunctionTable; +class LogicalToPhysicalColorAttachmentMap; +class RenderPipelineState; +class Resource; +class SamplerState; +struct ScissorRect; +class Texture; +struct VertexAmplificationViewMapping; +struct Viewport; +class VisibleFunctionTable; +_MTL_ENUM(NS::UInteger, PrimitiveType) { + PrimitiveTypePoint = 0, + PrimitiveTypeLine = 1, + PrimitiveTypeLineStrip = 2, + PrimitiveTypeTriangle = 3, + PrimitiveTypeTriangleStrip = 4, +}; + +_MTL_ENUM(NS::UInteger, VisibilityResultMode) { + VisibilityResultModeDisabled = 0, + VisibilityResultModeBoolean = 1, + VisibilityResultModeCounting = 2, +}; + +_MTL_ENUM(NS::UInteger, CullMode) { + CullModeNone = 0, + CullModeFront = 1, + CullModeBack = 2, +}; + +_MTL_ENUM(NS::UInteger, Winding) { + WindingClockwise = 0, + WindingCounterClockwise = 1, +}; + +_MTL_ENUM(NS::UInteger, DepthClipMode) { + DepthClipModeClip = 0, + DepthClipModeClamp = 1, +}; + +_MTL_ENUM(NS::UInteger, TriangleFillMode) { + TriangleFillModeFill = 0, + TriangleFillModeLines = 1, +}; + +_MTL_OPTIONS(NS::UInteger, RenderStages) { + RenderStageVertex = 1, + RenderStageFragment = 1 << 1, + RenderStageTile = 1 << 2, + RenderStageObject = 1 << 3, + RenderStageMesh = 1 << 4, +}; + +struct ScissorRect +{ + NS::UInteger x; + NS::UInteger y; + NS::UInteger width; + NS::UInteger height; +} _MTL_PACKED; + +struct Viewport +{ + double originX; + 
double originY; + double width; + double height; + double znear; + double zfar; +} _MTL_PACKED; + +struct DrawPrimitivesIndirectArguments +{ + uint32_t vertexCount; + uint32_t instanceCount; + uint32_t vertexStart; + uint32_t baseInstance; +} _MTL_PACKED; + +struct DrawIndexedPrimitivesIndirectArguments +{ + uint32_t indexCount; + uint32_t instanceCount; + uint32_t indexStart; + int32_t baseVertex; + uint32_t baseInstance; +} _MTL_PACKED; + +struct VertexAmplificationViewMapping +{ + uint32_t viewportArrayIndexOffset; + uint32_t renderTargetArrayIndexOffset; +} _MTL_PACKED; + +struct DrawPatchIndirectArguments +{ + uint32_t patchCount; + uint32_t instanceCount; + uint32_t patchStart; + uint32_t baseInstance; +} _MTL_PACKED; + +struct QuadTessellationFactorsHalf +{ + uint16_t edgeTessellationFactor[4]; + uint16_t insideTessellationFactor[2]; +} _MTL_PACKED; + +struct TriangleTessellationFactorsHalf +{ + uint16_t edgeTessellationFactor[3]; + uint16_t insideTessellationFactor; +} _MTL_PACKED; + +class RenderCommandEncoder : public NS::Referencing +{ +public: + void dispatchThreadsPerTile(MTL::Size threadsPerTile); + + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount); + void 
drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance); + void drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + void drawMeshThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup); + + void drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount); + void drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance); + void drawPrimitives(MTL::PrimitiveType 
primitiveType, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange); + void executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset); + + void memoryBarrier(MTL::BarrierScope scope, MTL::RenderStages after, MTL::RenderStages before); + void memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count, MTL::RenderStages after, MTL::RenderStages before); + + void sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier); + + void setBlendColor(float red, float green, float blue, float alpha); + + void setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping); + + void setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex); + void setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex); + + void setCullMode(MTL::CullMode cullMode); + + void setDepthBias(float depthBias, float slopeScale, float clamp); + + void setDepthClipMode(MTL::DepthClipMode depthClipMode); + + void setDepthStencilState(const MTL::DepthStencilState* depthStencilState); + + void setDepthStoreAction(MTL::StoreAction storeAction); + void setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setDepthTestBounds(float minBound, float maxBound); + + void setFragmentAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setFragmentBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setFragmentBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + + void setFragmentBytes(const void* bytes, 
NS::UInteger length, NS::UInteger index); + + void setFragmentIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setFragmentIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setFragmentSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setFragmentSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setFragmentSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setFragmentSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setFragmentTexture(const MTL::Texture* texture, NS::UInteger index); + void setFragmentTextures(const MTL::Texture* const textures[], NS::Range range); + + void setFragmentVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setFragmentVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setFrontFacingWinding(MTL::Winding frontFacingWinding); + + void setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setMeshBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setMeshBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setMeshBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setMeshSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setMeshSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setMeshSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setMeshSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, 
NS::Range range); + + void setMeshTexture(const MTL::Texture* texture, NS::UInteger index); + void setMeshTextures(const MTL::Texture* const textures[], NS::Range range); + + void setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setObjectBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setObjectBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setObjectBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setObjectSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setObjectSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setObjectSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setObjectSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range); + + void setObjectTexture(const MTL::Texture* texture, NS::UInteger index); + void setObjectTextures(const MTL::Texture* const textures[], NS::Range range); + + void setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index); + + void setRenderPipelineState(const MTL::RenderPipelineState* pipelineState); + + void setScissorRect(MTL::ScissorRect rect); + void setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count); + + void setStencilReferenceValue(uint32_t referenceValue); + void setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue); + + void setStencilStoreAction(MTL::StoreAction storeAction); + void setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setTessellationFactorBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride); + + void setTessellationFactorScale(float scale); + + void setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index); + + void 
setTileAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setTileBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setTileBufferOffset(NS::UInteger offset, NS::UInteger index); + + void setTileBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range); + + void setTileBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + + void setTileIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setTileIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setTileSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setTileSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setTileSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setTileSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setTileTexture(const MTL::Texture* texture, NS::UInteger index); + void setTileTextures(const MTL::Texture* const textures[], NS::Range range); + + void setTileVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setTileVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setTriangleFillMode(MTL::TriangleFillMode fillMode); + + void setVertexAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex); + + void setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings); + + void setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index); + void setVertexBuffer(const MTL::Buffer* buffer, 
NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + void setVertexBufferOffset(NS::UInteger offset, NS::UInteger index); + void setVertexBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index); + + void setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range); + void setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range); + + void setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger index); + void setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index); + + void setVertexIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex); + void setVertexIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range); + + void setVertexSamplerState(const MTL::SamplerState* sampler, NS::UInteger index); + void setVertexSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index); + void setVertexSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range); + void setVertexSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range); + + void setVertexTexture(const MTL::Texture* texture, NS::UInteger index); + void setVertexTextures(const MTL::Texture* const textures[], NS::Range range); + + void setVertexVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex); + void setVertexVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range); + + void setViewport(MTL::Viewport viewport); + void setViewports(const MTL::Viewport* viewports, NS::UInteger count); + + void setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset); + + void textureBarrier(); + + 
NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + void updateFence(const MTL::Fence* fence, MTL::RenderStages stages); + + void useHeap(const MTL::Heap* heap); + void useHeap(const MTL::Heap* heap, MTL::RenderStages stages); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count); + void useHeaps(const MTL::Heap* const heaps[], NS::UInteger count, MTL::RenderStages stages); + + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage); + void useResource(const MTL::Resource* resource, MTL::ResourceUsage usage, MTL::RenderStages stages); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage); + void useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage, MTL::RenderStages stages); + + void waitForFence(const MTL::Fence* fence, MTL::RenderStages stages); +}; + +} + +_MTL_INLINE void MTL::RenderCommandEncoder::dispatchThreadsPerTile(MTL::Size threadsPerTile) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(dispatchThreadsPerTile_), threadsPerTile); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_instanceCount_baseInstance_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPatches(NS::UInteger numberOfPatchControlPoints, const 
MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* controlPointIndexBuffer, NS::UInteger controlPointIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPatches_patchIndexBuffer_patchIndexBufferOffset_controlPointIndexBuffer_controlPointIndexBufferOffset_indirectBuffer_indirectBufferOffset_), numberOfPatchControlPoints, patchIndexBuffer, patchIndexBufferOffset, controlPointIndexBuffer, controlPointIndexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset, instanceCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_), primitiveType, indexCount, indexType, indexBuffer, indexBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger indexCount, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, NS::UInteger instanceCount, NS::Integer baseVertex, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexCount_indexType_indexBuffer_indexBufferOffset_instanceCount_baseVertex_baseInstance_), primitiveType, indexCount, indexType, indexBuffer, 
indexBufferOffset, instanceCount, baseVertex, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawIndexedPrimitives(MTL::PrimitiveType primitiveType, MTL::IndexType indexType, const MTL::Buffer* indexBuffer, NS::UInteger indexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawIndexedPrimitives_indexType_indexBuffer_indexBufferOffset_indirectBuffer_indirectBufferOffset_), primitiveType, indexType, indexBuffer, indexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreadgroups(MTL::Size threadgroupsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroups_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadgroupsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreadgroups(const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreadgroupsWithIndirectBuffer_indirectBufferOffset_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), indirectBuffer, indirectBufferOffset, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawMeshThreads(MTL::Size threadsPerGrid, MTL::Size threadsPerObjectThreadgroup, MTL::Size threadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawMeshThreads_threadsPerObjectThreadgroup_threadsPerMeshThreadgroup_), threadsPerGrid, threadsPerObjectThreadgroup, threadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPatches(NS::UInteger numberOfPatchControlPoints, NS::UInteger patchStart, NS::UInteger patchCount, const MTL::Buffer* patchIndexBuffer, NS::UInteger 
patchIndexBufferOffset, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchStart_patchCount_patchIndexBuffer_patchIndexBufferOffset_instanceCount_baseInstance_), numberOfPatchControlPoints, patchStart, patchCount, patchIndexBuffer, patchIndexBufferOffset, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPatches(NS::UInteger numberOfPatchControlPoints, const MTL::Buffer* patchIndexBuffer, NS::UInteger patchIndexBufferOffset, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPatches_patchIndexBuffer_patchIndexBufferOffset_indirectBuffer_indirectBufferOffset_), numberOfPatchControlPoints, patchIndexBuffer, patchIndexBufferOffset, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_), primitiveType, vertexStart, vertexCount, instanceCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_), primitiveType, vertexStart, vertexCount); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, NS::UInteger vertexStart, NS::UInteger vertexCount, NS::UInteger instanceCount, NS::UInteger baseInstance) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_vertexStart_vertexCount_instanceCount_baseInstance_), primitiveType, vertexStart, vertexCount, instanceCount, baseInstance); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::drawPrimitives(MTL::PrimitiveType primitiveType, const MTL::Buffer* 
indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(drawPrimitives_indirectBuffer_indirectBufferOffset_), primitiveType, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandBuffer, NS::Range executionRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_withRange_), indirectCommandBuffer, executionRange); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::executeCommandsInBuffer(const MTL::IndirectCommandBuffer* indirectCommandbuffer, const MTL::Buffer* indirectRangeBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(executeCommandsInBuffer_indirectBuffer_indirectBufferOffset_), indirectCommandbuffer, indirectRangeBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::memoryBarrier(MTL::BarrierScope scope, MTL::RenderStages after, MTL::RenderStages before) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithScope_afterStages_beforeStages_), scope, after, before); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::memoryBarrier(const MTL::Resource* const resources[], NS::UInteger count, MTL::RenderStages after, MTL::RenderStages before) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(memoryBarrierWithResources_count_afterStages_beforeStages_), resources, count, after, before); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::sampleCountersInBuffer(const MTL::CounterSampleBuffer* sampleBuffer, NS::UInteger sampleIndex, bool barrier) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCountersInBuffer_atSampleIndex_withBarrier_), sampleBuffer, sampleIndex, barrier); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setBlendColor(float red, float green, float blue, float alpha) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendColorRed_green_blue_alpha_), red, green, blue, alpha); +} + +_MTL_INLINE void 
MTL::RenderCommandEncoder::setColorAttachmentMap(const MTL::LogicalToPhysicalColorAttachmentMap* mapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorAttachmentMap_), mapping); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setColorStoreAction(MTL::StoreAction storeAction, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreAction_atIndex_), storeAction, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setColorStoreActionOptions(MTL::StoreActionOptions storeActionOptions, NS::UInteger colorAttachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setColorStoreActionOptions_atIndex_), storeActionOptions, colorAttachmentIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setCullMode(MTL::CullMode cullMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCullMode_), cullMode); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthBias(float depthBias, float slopeScale, float clamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthBias_slopeScale_clamp_), depthBias, slopeScale, clamp); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthClipMode(MTL::DepthClipMode depthClipMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthClipMode_), depthClipMode); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStencilState(const MTL::DepthStencilState* depthStencilState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStencilState_), depthStencilState); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setDepthTestBounds(float minBound, float 
maxBound) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthTestMinBound_maxBound_), minBound, maxBound); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentIntersectionFunctionTables_withBufferRange_), 
intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFragmentVisibleFunctionTables(const 
MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setFrontFacingWinding(MTL::Winding frontFacingWinding) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFrontFacingWinding_), frontFacingWinding); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setMeshTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setObjectSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectSamplerStates(const MTL::SamplerState* const samplers[], const float* lodMinClamps, const float* lodMaxClamps, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setObjectThreadgroupMemoryLength(NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupMemoryLength_atIndex_), length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setRenderPipelineState(const MTL::RenderPipelineState* pipelineState) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderPipelineState_), pipelineState); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setScissorRect(MTL::ScissorRect rect) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRect_), rect); +} + 
+_MTL_INLINE void MTL::RenderCommandEncoder::setScissorRects(const MTL::ScissorRect* scissorRects, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setScissorRects_count_), scissorRects, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilReferenceValue(uint32_t referenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilReferenceValue_), referenceValue); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilReferenceValues(uint32_t frontReferenceValue, uint32_t backReferenceValue) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilFrontReferenceValue_backReferenceValue_), frontReferenceValue, backReferenceValue); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setStencilStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilStoreActionOptions_), storeActionOptions); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTessellationFactorBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger instanceStride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorBuffer_offset_instanceStride_), buffer, offset, instanceStride); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTessellationFactorScale(float scale) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorScale_), scale); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setThreadgroupMemoryLength(NS::UInteger length, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_offset_atIndex_), length, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerState(const MTL::SamplerState* sampler, 
float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTileVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setTriangleFillMode(MTL::TriangleFillMode fillMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTriangleFillMode_), fillMode); +} + +_MTL_INLINE void 
MTL::RenderCommandEncoder::setVertexAccelerationStructure(const MTL::AccelerationStructure* accelerationStructure, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAccelerationStructure_atBufferIndex_), accelerationStructure, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexAmplificationCount(NS::UInteger count, const MTL::VertexAmplificationViewMapping* viewMappings) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAmplificationCount_viewMappings_), count, viewMappings); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_atIndex_), buffer, offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffer(const MTL::Buffer* buffer, NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffer_offset_attributeStride_atIndex_), buffer, offset, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBufferOffset(NS::UInteger offset, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_atIndex_), offset, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBufferOffset(NS::UInteger offset, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBufferOffset_attributeStride_atIndex_), offset, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger offsets[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_offsets_withRange_), buffers, offsets, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBuffers(const MTL::Buffer* const buffers[], const NS::UInteger* offsets, const NS::UInteger* strides, NS::Range range) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBuffers_offsets_attributeStrides_withRange_), buffers, offsets, strides, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBytes_length_atIndex_), bytes, length, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexBytes(const void* bytes, NS::UInteger length, NS::UInteger stride, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexBytes_length_attributeStride_atIndex_), bytes, length, stride, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexIntersectionFunctionTable(const MTL::IntersectionFunctionTable* intersectionFunctionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexIntersectionFunctionTable_atBufferIndex_), intersectionFunctionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexIntersectionFunctionTables(const MTL::IntersectionFunctionTable* const intersectionFunctionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexIntersectionFunctionTables_withBufferRange_), intersectionFunctionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerState(const MTL::SamplerState* sampler, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerState_atIndex_), sampler, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerState(const MTL::SamplerState* sampler, float lodMinClamp, float lodMaxClamp, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerState_lodMinClamp_lodMaxClamp_atIndex_), sampler, lodMinClamp, lodMaxClamp, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerStates(const MTL::SamplerState* const samplers[], NS::Range range) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setVertexSamplerStates_withRange_), samplers, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexSamplerStates(const MTL::SamplerState* const samplers[], const float lodMinClamps[], const float lodMaxClamps[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexSamplerStates_lodMinClamps_lodMaxClamps_withRange_), samplers, lodMinClamps, lodMaxClamps, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexTexture(const MTL::Texture* texture, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexTexture_atIndex_), texture, index); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexTextures(const MTL::Texture* const textures[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexTextures_withRange_), textures, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexVisibleFunctionTable(const MTL::VisibleFunctionTable* functionTable, NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexVisibleFunctionTable_atBufferIndex_), functionTable, bufferIndex); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVertexVisibleFunctionTables(const MTL::VisibleFunctionTable* const functionTables[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexVisibleFunctionTables_withBufferRange_), functionTables, range); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setViewport(MTL::Viewport viewport) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewport_), viewport); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setViewports(const MTL::Viewport* viewports, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setViewports_count_), viewports, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::setVisibilityResultMode(MTL::VisibilityResultMode mode, NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultMode_offset_), mode, offset); +} + 
+_MTL_INLINE void MTL::RenderCommandEncoder::textureBarrier() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(textureBarrier)); +} + +_MTL_INLINE NS::UInteger MTL::RenderCommandEncoder::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderCommandEncoder::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::updateFence(const MTL::Fence* fence, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_afterStages_), fence, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeap(const MTL::Heap* heap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_), heap); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeap(const MTL::Heap* heap, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeap_stages_), heap, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_), heaps, count); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useHeaps(const MTL::Heap* const heaps[], NS::UInteger count, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useHeaps_count_stages_), heaps, count, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_), resource, usage); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResource(const MTL::Resource* resource, MTL::ResourceUsage usage, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResource_usage_stages_), resource, usage, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_), resources, count, usage); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::useResources(const MTL::Resource* const resources[], NS::UInteger count, MTL::ResourceUsage usage, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(useResources_count_usage_stages_), resources, count, usage, stages); +} + +_MTL_INLINE void MTL::RenderCommandEncoder::waitForFence(const MTL::Fence* fence, MTL::RenderStages stages) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_beforeStages_), fence, stages); +} diff --git a/dist/include/metal_cpp/Metal/MTLRenderPass.hpp b/dist/include/metal_cpp/Metal/MTLRenderPass.hpp new file mode 100644 index 0000000..ed2172d --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLRenderPass.hpp @@ -0,0 +1,792 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRenderPass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Buffer; +class CounterSampleBuffer; +class RasterizationRateMap; +class RenderPassAttachmentDescriptor; +class RenderPassColorAttachmentDescriptor; +class RenderPassColorAttachmentDescriptorArray; +class RenderPassDepthAttachmentDescriptor; +class RenderPassDescriptor; +class RenderPassSampleBufferAttachmentDescriptor; +class RenderPassSampleBufferAttachmentDescriptorArray; +class RenderPassStencilAttachmentDescriptor; +struct SamplePosition; +class Texture; +_MTL_ENUM(NS::UInteger, LoadAction) { + LoadActionDontCare = 0, + LoadActionLoad = 1, + LoadActionClear = 2, +}; + +_MTL_ENUM(NS::UInteger, StoreAction) { + StoreActionDontCare = 0, + StoreActionStore = 1, + StoreActionMultisampleResolve = 2, + StoreActionStoreAndMultisampleResolve = 3, + StoreActionUnknown = 4, + StoreActionCustomSampleDepthStore = 5, +}; + +_MTL_ENUM(NS::Integer, VisibilityResultType) { + VisibilityResultTypeReset = 0, + VisibilityResultTypeAccumulate = 1, +}; + +_MTL_ENUM(NS::UInteger, MultisampleDepthResolveFilter) { + MultisampleDepthResolveFilterSample0 = 0, + MultisampleDepthResolveFilterMin = 1, + MultisampleDepthResolveFilterMax = 2, +}; + +_MTL_ENUM(NS::UInteger, MultisampleStencilResolveFilter) { + MultisampleStencilResolveFilterSample0 = 0, + MultisampleStencilResolveFilterDepthResolvedSample = 1, +}; + +_MTL_OPTIONS(NS::UInteger, StoreActionOptions) { + StoreActionOptionNone = 0, + StoreActionOptionCustomSamplePositions = 1, + StoreActionOptionValidMask = 1, +}; + +struct ClearColor +{ + ClearColor() = default; + + ClearColor(double red, double green, double blue, double alpha); + + static ClearColor Make(double red, double green, double 
blue, double alpha); + + double red; + double green; + double blue; + double alpha; +} _MTL_PACKED; + +class RenderPassAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassAttachmentDescriptor* alloc(); + + NS::UInteger depthPlane() const; + + RenderPassAttachmentDescriptor* init(); + + NS::UInteger level() const; + + LoadAction loadAction() const; + + NS::UInteger resolveDepthPlane() const; + + NS::UInteger resolveLevel() const; + + NS::UInteger resolveSlice() const; + + Texture* resolveTexture() const; + + void setDepthPlane(NS::UInteger depthPlane); + + void setLevel(NS::UInteger level); + + void setLoadAction(MTL::LoadAction loadAction); + + void setResolveDepthPlane(NS::UInteger resolveDepthPlane); + + void setResolveLevel(NS::UInteger resolveLevel); + + void setResolveSlice(NS::UInteger resolveSlice); + + void setResolveTexture(const MTL::Texture* resolveTexture); + + void setSlice(NS::UInteger slice); + + void setStoreAction(MTL::StoreAction storeAction); + void setStoreActionOptions(MTL::StoreActionOptions storeActionOptions); + + void setTexture(const MTL::Texture* texture); + + NS::UInteger slice() const; + + StoreAction storeAction() const; + StoreActionOptions storeActionOptions() const; + + Texture* texture() const; +}; +class RenderPassColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassColorAttachmentDescriptor* alloc(); + + ClearColor clearColor() const; + + RenderPassColorAttachmentDescriptor* init(); + + void setClearColor(MTL::ClearColor clearColor); +}; +class RenderPassDepthAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassDepthAttachmentDescriptor* alloc(); + + double clearDepth() const; + + MultisampleDepthResolveFilter depthResolveFilter() const; + + RenderPassDepthAttachmentDescriptor* init(); + + void setClearDepth(double clearDepth); + + void setDepthResolveFilter(MTL::MultisampleDepthResolveFilter depthResolveFilter); +}; +class RenderPassStencilAttachmentDescriptor : 
public NS::Copying +{ +public: + static RenderPassStencilAttachmentDescriptor* alloc(); + + uint32_t clearStencil() const; + + RenderPassStencilAttachmentDescriptor* init(); + + void setClearStencil(uint32_t clearStencil); + + void setStencilResolveFilter(MTL::MultisampleStencilResolveFilter stencilResolveFilter); + MultisampleStencilResolveFilter stencilResolveFilter() const; +}; +class RenderPassColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPassColorAttachmentDescriptorArray* alloc(); + + RenderPassColorAttachmentDescriptorArray* init(); + + RenderPassColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPassColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class RenderPassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfFragmentSampleIndex() const; + + NS::UInteger endOfVertexSampleIndex() const; + + RenderPassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfFragmentSampleIndex(NS::UInteger endOfFragmentSampleIndex); + + void setEndOfVertexSampleIndex(NS::UInteger endOfVertexSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfFragmentSampleIndex(NS::UInteger startOfFragmentSampleIndex); + + void setStartOfVertexSampleIndex(NS::UInteger startOfVertexSampleIndex); + + NS::UInteger startOfFragmentSampleIndex() const; + + NS::UInteger startOfVertexSampleIndex() const; +}; +class RenderPassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPassSampleBufferAttachmentDescriptorArray* alloc(); + + RenderPassSampleBufferAttachmentDescriptorArray* init(); + + RenderPassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPassSampleBufferAttachmentDescriptor* attachment, 
NS::UInteger attachmentIndex); +}; +class RenderPassDescriptor : public NS::Copying +{ +public: + static RenderPassDescriptor* alloc(); + + RenderPassColorAttachmentDescriptorArray* colorAttachments() const; + + NS::UInteger defaultRasterSampleCount() const; + + RenderPassDepthAttachmentDescriptor* depthAttachment() const; + + NS::UInteger getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count); + + NS::UInteger imageblockSampleLength() const; + + RenderPassDescriptor* init(); + + RasterizationRateMap* rasterizationRateMap() const; + + static RenderPassDescriptor* renderPassDescriptor(); + + NS::UInteger renderTargetArrayLength() const; + + NS::UInteger renderTargetHeight() const; + + NS::UInteger renderTargetWidth() const; + + RenderPassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; + + void setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount); + + void setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment); + + void setImageblockSampleLength(NS::UInteger imageblockSampleLength); + + void setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap); + + void setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength); + + void setRenderTargetHeight(NS::UInteger renderTargetHeight); + + void setRenderTargetWidth(NS::UInteger renderTargetWidth); + + void setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count); + + void setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment); + + void setSupportColorAttachmentMapping(bool supportColorAttachmentMapping); + + void setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength); + + void setTileHeight(NS::UInteger tileHeight); + + void setTileWidth(NS::UInteger tileWidth); + + void setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer); + + void setVisibilityResultType(MTL::VisibilityResultType visibilityResultType); + + 
RenderPassStencilAttachmentDescriptor* stencilAttachment() const; + + bool supportColorAttachmentMapping() const; + + NS::UInteger threadgroupMemoryLength() const; + + NS::UInteger tileHeight() const; + + NS::UInteger tileWidth() const; + + Buffer* visibilityResultBuffer() const; + + VisibilityResultType visibilityResultType() const; +}; + +} +_MTL_INLINE MTL::ClearColor::ClearColor(double red, double green, double blue, double alpha) + : red(red) + , green(green) + , blue(blue) + , alpha(alpha) +{ +} + +_MTL_INLINE MTL::ClearColor MTL::ClearColor::Make(double red, double green, double blue, double alpha) +{ + return ClearColor(red, green, blue, alpha); +} + +_MTL_INLINE MTL::RenderPassAttachmentDescriptor* MTL::RenderPassAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::depthPlane() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthPlane)); +} + +_MTL_INLINE MTL::RenderPassAttachmentDescriptor* MTL::RenderPassAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::level() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(level)); +} + +_MTL_INLINE MTL::LoadAction MTL::RenderPassAttachmentDescriptor::loadAction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(loadAction)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveDepthPlane() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveDepthPlane)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveLevel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveLevel)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::resolveSlice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveSlice)); +} + +_MTL_INLINE MTL::Texture* 
MTL::RenderPassAttachmentDescriptor::resolveTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resolveTexture)); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setDepthPlane(NS::UInteger depthPlane) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthPlane_), depthPlane); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setLevel(NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevel_), level); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setLoadAction(MTL::LoadAction loadAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLoadAction_), loadAction); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveDepthPlane(NS::UInteger resolveDepthPlane) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveDepthPlane_), resolveDepthPlane); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveLevel(NS::UInteger resolveLevel) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveLevel_), resolveLevel); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveSlice(NS::UInteger resolveSlice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveSlice_), resolveSlice); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setResolveTexture(const MTL::Texture* resolveTexture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResolveTexture_), resolveTexture); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setSlice(NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSlice_), slice); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setStoreAction(MTL::StoreAction storeAction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStoreAction_), storeAction); +} + +_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setStoreActionOptions(MTL::StoreActionOptions storeActionOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStoreActionOptions_), storeActionOptions); +} + 
+_MTL_INLINE void MTL::RenderPassAttachmentDescriptor::setTexture(const MTL::Texture* texture) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTexture_), texture); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassAttachmentDescriptor::slice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(slice)); +} + +_MTL_INLINE MTL::StoreAction MTL::RenderPassAttachmentDescriptor::storeAction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storeAction)); +} + +_MTL_INLINE MTL::StoreActionOptions MTL::RenderPassAttachmentDescriptor::storeActionOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storeActionOptions)); +} + +_MTL_INLINE MTL::Texture* MTL::RenderPassAttachmentDescriptor::texture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(texture)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::ClearColor MTL::RenderPassColorAttachmentDescriptor::clearColor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearColor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassColorAttachmentDescriptor::setClearColor(MTL::ClearColor clearColor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearColor_), clearColor); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDepthAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassDepthAttachmentDescriptor)); +} + +_MTL_INLINE double MTL::RenderPassDepthAttachmentDescriptor::clearDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearDepth)); +} + +_MTL_INLINE MTL::MultisampleDepthResolveFilter MTL::RenderPassDepthAttachmentDescriptor::depthResolveFilter() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(depthResolveFilter)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDepthAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassDepthAttachmentDescriptor::setClearDepth(double clearDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearDepth_), clearDepth); +} + +_MTL_INLINE void MTL::RenderPassDepthAttachmentDescriptor::setDepthResolveFilter(MTL::MultisampleDepthResolveFilter depthResolveFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthResolveFilter_), depthResolveFilter); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassStencilAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassStencilAttachmentDescriptor)); +} + +_MTL_INLINE uint32_t MTL::RenderPassStencilAttachmentDescriptor::clearStencil() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(clearStencil)); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassStencilAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPassStencilAttachmentDescriptor::setClearStencil(uint32_t clearStencil) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setClearStencil_), clearStencil); +} + +_MTL_INLINE void MTL::RenderPassStencilAttachmentDescriptor::setStencilResolveFilter(MTL::MultisampleStencilResolveFilter stencilResolveFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilResolveFilter_), stencilResolveFilter); +} + +_MTL_INLINE MTL::MultisampleStencilResolveFilter MTL::RenderPassStencilAttachmentDescriptor::stencilResolveFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilResolveFilter)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassColorAttachmentDescriptorArray)); +} + 
+_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptor* MTL::RenderPassColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPassColorAttachmentDescriptorArray::setObject(const MTL::RenderPassColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::endOfFragmentSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfFragmentSampleIndex)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::endOfVertexSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfVertexSampleIndex)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::RenderPassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setEndOfFragmentSampleIndex(NS::UInteger endOfFragmentSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfFragmentSampleIndex_), endOfFragmentSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setEndOfVertexSampleIndex(NS::UInteger endOfVertexSampleIndex) 
+{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfVertexSampleIndex_), endOfVertexSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setStartOfFragmentSampleIndex(NS::UInteger startOfFragmentSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfFragmentSampleIndex_), startOfFragmentSampleIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptor::setStartOfVertexSampleIndex(NS::UInteger startOfVertexSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfVertexSampleIndex_), startOfVertexSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::startOfFragmentSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfFragmentSampleIndex)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassSampleBufferAttachmentDescriptor::startOfVertexSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfVertexSampleIndex)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptor* MTL::RenderPassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPassSampleBufferAttachmentDescriptorArray::setObject(const 
MTL::RenderPassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPassDescriptor)); +} + +_MTL_INLINE MTL::RenderPassColorAttachmentDescriptorArray* MTL::RenderPassDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::defaultRasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(defaultRasterSampleCount)); +} + +_MTL_INLINE MTL::RenderPassDepthAttachmentDescriptor* MTL::RenderPassDescriptor::depthAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachment)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::getSamplePositions(MTL::SamplePosition* positions, NS::UInteger count) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getSamplePositions_count_), positions, count); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RasterizationRateMap* MTL::RenderPassDescriptor::rasterizationRateMap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterizationRateMap)); +} + +_MTL_INLINE MTL::RenderPassDescriptor* MTL::RenderPassDescriptor::renderPassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLRenderPassDescriptor), _MTL_PRIVATE_SEL(renderPassDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::renderTargetArrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetArrayLength)); +} + +_MTL_INLINE 
NS::UInteger MTL::RenderPassDescriptor::renderTargetHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::renderTargetWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(renderTargetWidth)); +} + +_MTL_INLINE MTL::RenderPassSampleBufferAttachmentDescriptorArray* MTL::RenderPassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setDefaultRasterSampleCount(NS::UInteger defaultRasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDefaultRasterSampleCount_), defaultRasterSampleCount); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setDepthAttachment(const MTL::RenderPassDepthAttachmentDescriptor* depthAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachment_), depthAttachment); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setImageblockSampleLength(NS::UInteger imageblockSampleLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setImageblockSampleLength_), imageblockSampleLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRasterizationRateMap(const MTL::RasterizationRateMap* rasterizationRateMap) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationRateMap_), rasterizationRateMap); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetArrayLength(NS::UInteger renderTargetArrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetArrayLength_), renderTargetArrayLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetHeight(NS::UInteger renderTargetHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetHeight_), renderTargetHeight); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setRenderTargetWidth(NS::UInteger renderTargetWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRenderTargetWidth_), 
renderTargetWidth); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setSamplePositions(const MTL::SamplePosition* positions, NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSamplePositions_count_), positions, count); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setStencilAttachment(const MTL::RenderPassStencilAttachmentDescriptor* stencilAttachment) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachment_), stencilAttachment); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setSupportColorAttachmentMapping(bool supportColorAttachmentMapping) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportColorAttachmentMapping_), supportColorAttachmentMapping); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setThreadgroupMemoryLength(NS::UInteger threadgroupMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupMemoryLength_), threadgroupMemoryLength); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setTileHeight(NS::UInteger tileHeight) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileHeight_), tileHeight); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setTileWidth(NS::UInteger tileWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileWidth_), tileWidth); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setVisibilityResultBuffer(const MTL::Buffer* visibilityResultBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultBuffer_), visibilityResultBuffer); +} + +_MTL_INLINE void MTL::RenderPassDescriptor::setVisibilityResultType(MTL::VisibilityResultType visibilityResultType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVisibilityResultType_), visibilityResultType); +} + +_MTL_INLINE MTL::RenderPassStencilAttachmentDescriptor* MTL::RenderPassDescriptor::stencilAttachment() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachment)); +} + +_MTL_INLINE bool MTL::RenderPassDescriptor::supportColorAttachmentMapping() const +{ + return 
Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportColorAttachmentMapping)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::threadgroupMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::tileHeight() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileHeight)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPassDescriptor::tileWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileWidth)); +} + +_MTL_INLINE MTL::Buffer* MTL::RenderPassDescriptor::visibilityResultBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultBuffer)); +} + +_MTL_INLINE MTL::VisibilityResultType MTL::RenderPassDescriptor::visibilityResultType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(visibilityResultType)); +} diff --git a/dist/include/metal_cpp/Metal/MTLRenderPipeline.hpp b/dist/include/metal_cpp/Metal/MTLRenderPipeline.hpp new file mode 100644 index 0000000..aaa9cda --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLRenderPipeline.hpp @@ -0,0 +1,1876 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLRenderPipeline.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPipeline.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class Function; +class FunctionHandle; +class IntersectionFunctionTable; +class IntersectionFunctionTableDescriptor; +class LinkedFunctions; +class LogicalToPhysicalColorAttachmentMap; +class MeshRenderPipelineDescriptor; +class PipelineBufferDescriptorArray; +class RenderPipelineColorAttachmentDescriptor; +class RenderPipelineColorAttachmentDescriptorArray; +class RenderPipelineDescriptor; +class RenderPipelineFunctionsDescriptor; +class RenderPipelineReflection; +class RenderPipelineState; +class TileRenderPipelineColorAttachmentDescriptor; +class TileRenderPipelineColorAttachmentDescriptorArray; +class TileRenderPipelineDescriptor; +class VertexDescriptor; +class VisibleFunctionTable; +class VisibleFunctionTableDescriptor; + +} +namespace MTL4 +{ +class BinaryFunction; +class PipelineDescriptor; +class RenderPipelineBinaryFunctionsDescriptor; + +} +namespace MTL +{ +_MTL_ENUM(NS::UInteger, BlendFactor) { + BlendFactorZero = 0, + BlendFactorOne = 1, + BlendFactorSourceColor = 2, + BlendFactorOneMinusSourceColor = 3, + BlendFactorSourceAlpha = 4, + BlendFactorOneMinusSourceAlpha = 5, + BlendFactorDestinationColor = 6, + BlendFactorOneMinusDestinationColor = 7, + BlendFactorDestinationAlpha = 8, + BlendFactorOneMinusDestinationAlpha = 9, + BlendFactorSourceAlphaSaturated = 10, + BlendFactorBlendColor = 11, + BlendFactorOneMinusBlendColor = 12, + BlendFactorBlendAlpha = 13, + BlendFactorOneMinusBlendAlpha = 14, + BlendFactorSource1Color = 15, + BlendFactorOneMinusSource1Color 
= 16, + BlendFactorSource1Alpha = 17, + BlendFactorOneMinusSource1Alpha = 18, + BlendFactorUnspecialized = 19, +}; + +_MTL_ENUM(NS::UInteger, BlendOperation) { + BlendOperationAdd = 0, + BlendOperationSubtract = 1, + BlendOperationReverseSubtract = 2, + BlendOperationMin = 3, + BlendOperationMax = 4, + BlendOperationUnspecialized = 5, +}; + +_MTL_ENUM(NS::UInteger, PrimitiveTopologyClass) { + PrimitiveTopologyClassUnspecified = 0, + PrimitiveTopologyClassPoint = 1, + PrimitiveTopologyClassLine = 2, + PrimitiveTopologyClassTriangle = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationPartitionMode) { + TessellationPartitionModePow2 = 0, + TessellationPartitionModeInteger = 1, + TessellationPartitionModeFractionalOdd = 2, + TessellationPartitionModeFractionalEven = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationFactorStepFunction) { + TessellationFactorStepFunctionConstant = 0, + TessellationFactorStepFunctionPerPatch = 1, + TessellationFactorStepFunctionPerInstance = 2, + TessellationFactorStepFunctionPerPatchAndPerInstance = 3, +}; + +_MTL_ENUM(NS::UInteger, TessellationFactorFormat) { + TessellationFactorFormatHalf = 0, +}; + +_MTL_ENUM(NS::UInteger, TessellationControlPointIndexType) { + TessellationControlPointIndexTypeNone = 0, + TessellationControlPointIndexTypeUInt16 = 1, + TessellationControlPointIndexTypeUInt32 = 2, +}; + +_MTL_OPTIONS(NS::UInteger, ColorWriteMask) { + ColorWriteMaskNone = 0, + ColorWriteMaskRed = 1 << 3, + ColorWriteMaskGreen = 1 << 2, + ColorWriteMaskBlue = 1 << 1, + ColorWriteMaskAlpha = 1, + ColorWriteMaskAll = 15, + ColorWriteMaskUnspecialized = 1 << 4, +}; + +class RenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static RenderPipelineColorAttachmentDescriptor* alloc(); + + BlendOperation alphaBlendOperation() const; + + [[deprecated("please use isBlendingEnabled instead")]] + bool blendingEnabled() const; + + BlendFactor destinationAlphaBlendFactor() const; + + BlendFactor destinationRGBBlendFactor() const; + + 
RenderPipelineColorAttachmentDescriptor* init(); + + bool isBlendingEnabled() const; + + PixelFormat pixelFormat() const; + + BlendOperation rgbBlendOperation() const; + + void setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation); + + void setBlendingEnabled(bool blendingEnabled); + + void setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor); + + void setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation); + + void setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor); + + void setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor); + + void setWriteMask(MTL::ColorWriteMask writeMask); + + BlendFactor sourceAlphaBlendFactor() const; + + BlendFactor sourceRGBBlendFactor() const; + + ColorWriteMask writeMask() const; +}; +class LogicalToPhysicalColorAttachmentMap : public NS::Copying +{ +public: + static LogicalToPhysicalColorAttachmentMap* alloc(); + + NS::UInteger getPhysicalIndex(NS::UInteger logicalIndex); + + LogicalToPhysicalColorAttachmentMap* init(); + + void reset(); + + void setPhysicalIndex(NS::UInteger physicalIndex, NS::UInteger logicalIndex); +}; +class RenderPipelineReflection : public NS::Referencing +{ +public: + static RenderPipelineReflection* alloc(); + + NS::Array* fragmentArguments() const; + + NS::Array* fragmentBindings() const; + + RenderPipelineReflection* init(); + + NS::Array* meshBindings() const; + + NS::Array* objectBindings() const; + + NS::Array* tileArguments() const; + + NS::Array* tileBindings() const; + + NS::Array* vertexArguments() const; + + NS::Array* vertexBindings() const; +}; +class RenderPipelineDescriptor : public NS::Copying +{ +public: + static RenderPipelineDescriptor* alloc(); + + [[deprecated("please use isAlphaToCoverageEnabled instead")]] + bool alphaToCoverageEnabled() const; + + [[deprecated("please use 
isAlphaToOneEnabled instead")]] + bool alphaToOneEnabled() const; + + NS::Array* binaryArchives() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + PixelFormat depthAttachmentPixelFormat() const; + + PipelineBufferDescriptorArray* fragmentBuffers() const; + + Function* fragmentFunction() const; + + LinkedFunctions* fragmentLinkedFunctions() const; + + NS::Array* fragmentPreloadedLibraries() const; + + RenderPipelineDescriptor* init(); + + PrimitiveTopologyClass inputPrimitiveTopology() const; + + bool isAlphaToCoverageEnabled() const; + + bool isAlphaToOneEnabled() const; + + bool isRasterizationEnabled() const; + + bool isTessellationFactorScaleEnabled() const; + + NS::String* label() const; + + NS::UInteger maxFragmentCallStackDepth() const; + + NS::UInteger maxTessellationFactor() const; + + NS::UInteger maxVertexAmplificationCount() const; + + NS::UInteger maxVertexCallStackDepth() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + void reset(); + + NS::UInteger sampleCount() const; + + void setAlphaToCoverageEnabled(bool alphaToCoverageEnabled); + + void setAlphaToOneEnabled(bool alphaToOneEnabled); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat); + + void setFragmentFunction(const MTL::Function* fragmentFunction); + + void setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions); + + void setFragmentPreloadedLibraries(const NS::Array* fragmentPreloadedLibraries); + + void setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology); + + void setLabel(const NS::String* label); + + void setMaxFragmentCallStackDepth(NS::UInteger maxFragmentCallStackDepth); + + void setMaxTessellationFactor(NS::UInteger maxTessellationFactor); + + void setMaxVertexAmplificationCount(NS::UInteger 
maxVertexAmplificationCount); + + void setMaxVertexCallStackDepth(NS::UInteger maxVertexCallStackDepth); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setSampleCount(NS::UInteger sampleCount); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat); + + void setSupportAddingFragmentBinaryFunctions(bool supportAddingFragmentBinaryFunctions); + + void setSupportAddingVertexBinaryFunctions(bool supportAddingVertexBinaryFunctions); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + void setTessellationControlPointIndexType(MTL::TessellationControlPointIndexType tessellationControlPointIndexType); + + void setTessellationFactorFormat(MTL::TessellationFactorFormat tessellationFactorFormat); + + void setTessellationFactorScaleEnabled(bool tessellationFactorScaleEnabled); + + void setTessellationFactorStepFunction(MTL::TessellationFactorStepFunction tessellationFactorStepFunction); + + void setTessellationOutputWindingOrder(MTL::Winding tessellationOutputWindingOrder); + + void setTessellationPartitionMode(MTL::TessellationPartitionMode tessellationPartitionMode); + + void setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor); + + void setVertexFunction(const MTL::Function* vertexFunction); + + void setVertexLinkedFunctions(const MTL::LinkedFunctions* vertexLinkedFunctions); + + void setVertexPreloadedLibraries(const NS::Array* vertexPreloadedLibraries); + + ShaderValidation shaderValidation() const; + + PixelFormat stencilAttachmentPixelFormat() const; + + bool supportAddingFragmentBinaryFunctions() const; + + bool supportAddingVertexBinaryFunctions() const; + + bool supportIndirectCommandBuffers() const; + + TessellationControlPointIndexType tessellationControlPointIndexType() const; + + TessellationFactorFormat tessellationFactorFormat() 
const; + + [[deprecated("please use isTessellationFactorScaleEnabled instead")]] + bool tessellationFactorScaleEnabled() const; + + TessellationFactorStepFunction tessellationFactorStepFunction() const; + + Winding tessellationOutputWindingOrder() const; + + TessellationPartitionMode tessellationPartitionMode() const; + + PipelineBufferDescriptorArray* vertexBuffers() const; + + VertexDescriptor* vertexDescriptor() const; + + Function* vertexFunction() const; + + LinkedFunctions* vertexLinkedFunctions() const; + + NS::Array* vertexPreloadedLibraries() const; +}; +class RenderPipelineFunctionsDescriptor : public NS::Copying +{ +public: + static RenderPipelineFunctionsDescriptor* alloc(); + + NS::Array* fragmentAdditionalBinaryFunctions() const; + + RenderPipelineFunctionsDescriptor* init(); + + void setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions); + + void setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions); + + void setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions); + + NS::Array* tileAdditionalBinaryFunctions() const; + + NS::Array* vertexAdditionalBinaryFunctions() const; +}; +class RenderPipelineState : public NS::Referencing +{ +public: + Device* device() const; + + FunctionHandle* functionHandle(const NS::String* name, MTL::RenderStages stage); + FunctionHandle* functionHandle(const MTL4::BinaryFunction* function, MTL::RenderStages stage); + FunctionHandle* functionHandle(const MTL::Function* function, MTL::RenderStages stage); + + ResourceID gpuResourceID() const; + + NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions); + + NS::UInteger imageblockSampleLength() const; + + NS::String* label() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + 
NS::UInteger meshThreadExecutionWidth() const; + + IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor, MTL::RenderStages stage); + + MTL4::PipelineDescriptor* newRenderPipelineDescriptor(); + + RenderPipelineState* newRenderPipelineState(const MTL4::RenderPipelineBinaryFunctionsDescriptor* binaryFunctionsDescriptor, NS::Error** error); + RenderPipelineState* newRenderPipelineState(const MTL::RenderPipelineFunctionsDescriptor* additionalBinaryFunctions, NS::Error** error); + + VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor, MTL::RenderStages stage); + + NS::UInteger objectThreadExecutionWidth() const; + + RenderPipelineReflection* reflection() const; + + Size requiredThreadsPerMeshThreadgroup() const; + + Size requiredThreadsPerObjectThreadgroup() const; + + Size requiredThreadsPerTileThreadgroup() const; + + ShaderValidation shaderValidation() const; + + bool supportIndirectCommandBuffers() const; + + bool threadgroupSizeMatchesTileSize() const; +}; +class RenderPipelineColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static RenderPipelineColorAttachmentDescriptorArray* alloc(); + + RenderPipelineColorAttachmentDescriptorArray* init(); + + RenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class TileRenderPipelineColorAttachmentDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineColorAttachmentDescriptor* alloc(); + + TileRenderPipelineColorAttachmentDescriptor* init(); + + PixelFormat pixelFormat() const; + void setPixelFormat(MTL::PixelFormat pixelFormat); +}; +class TileRenderPipelineColorAttachmentDescriptorArray : public NS::Referencing +{ +public: + static TileRenderPipelineColorAttachmentDescriptorArray* alloc(); + + 
TileRenderPipelineColorAttachmentDescriptorArray* init(); + + TileRenderPipelineColorAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::TileRenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class TileRenderPipelineDescriptor : public NS::Copying +{ +public: + static TileRenderPipelineDescriptor* alloc(); + + NS::Array* binaryArchives() const; + + TileRenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + TileRenderPipelineDescriptor* init(); + + NS::String* label() const; + + LinkedFunctions* linkedFunctions() const; + + NS::UInteger maxCallStackDepth() const; + + NS::UInteger maxTotalThreadsPerThreadgroup() const; + + NS::Array* preloadedLibraries() const; + + NS::UInteger rasterSampleCount() const; + + Size requiredThreadsPerThreadgroup() const; + + void reset(); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void setLabel(const NS::String* label); + + void setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions); + + void setMaxCallStackDepth(NS::UInteger maxCallStackDepth); + + void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup); + + void setPreloadedLibraries(const NS::Array* preloadedLibraries); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions); + + void setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize); + + void setTileFunction(const MTL::Function* tileFunction); + + ShaderValidation shaderValidation() const; + + bool supportAddingBinaryFunctions() const; + + bool threadgroupSizeMatchesTileSize() const; + + PipelineBufferDescriptorArray* tileBuffers() const; + + Function* tileFunction() const; +}; +class MeshRenderPipelineDescriptor : public 
NS::Copying +{ +public: + static MeshRenderPipelineDescriptor* alloc(); + + [[deprecated("please use isAlphaToCoverageEnabled instead")]] + bool alphaToCoverageEnabled() const; + + [[deprecated("please use isAlphaToOneEnabled instead")]] + bool alphaToOneEnabled() const; + + NS::Array* binaryArchives() const; + + RenderPipelineColorAttachmentDescriptorArray* colorAttachments() const; + + PixelFormat depthAttachmentPixelFormat() const; + + PipelineBufferDescriptorArray* fragmentBuffers() const; + + Function* fragmentFunction() const; + + LinkedFunctions* fragmentLinkedFunctions() const; + + MeshRenderPipelineDescriptor* init(); + + bool isAlphaToCoverageEnabled() const; + + bool isAlphaToOneEnabled() const; + + bool isRasterizationEnabled() const; + + NS::String* label() const; + + NS::UInteger maxTotalThreadgroupsPerMeshGrid() const; + + NS::UInteger maxTotalThreadsPerMeshThreadgroup() const; + + NS::UInteger maxTotalThreadsPerObjectThreadgroup() const; + + NS::UInteger maxVertexAmplificationCount() const; + + PipelineBufferDescriptorArray* meshBuffers() const; + + Function* meshFunction() const; + + LinkedFunctions* meshLinkedFunctions() const; + + bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + PipelineBufferDescriptorArray* objectBuffers() const; + + Function* objectFunction() const; + + LinkedFunctions* objectLinkedFunctions() const; + + bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const; + + NS::UInteger payloadMemoryLength() const; + + NS::UInteger rasterSampleCount() const; + + [[deprecated("please use isRasterizationEnabled instead")]] + bool rasterizationEnabled() const; + + Size requiredThreadsPerMeshThreadgroup() const; + + Size requiredThreadsPerObjectThreadgroup() const; + + void reset(); + + void setAlphaToCoverageEnabled(bool alphaToCoverageEnabled); + + void setAlphaToOneEnabled(bool alphaToOneEnabled); + + void setBinaryArchives(const NS::Array* binaryArchives); + + void 
setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat); + + void setFragmentFunction(const MTL::Function* fragmentFunction); + + void setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions); + + void setLabel(const NS::String* label); + + void setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid); + + void setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup); + + void setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup); + + void setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount); + + void setMeshFunction(const MTL::Function* meshFunction); + + void setMeshLinkedFunctions(const MTL::LinkedFunctions* meshLinkedFunctions); + + void setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setObjectFunction(const MTL::Function* objectFunction); + + void setObjectLinkedFunctions(const MTL::LinkedFunctions* objectLinkedFunctions); + + void setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); + + void setPayloadMemoryLength(NS::UInteger payloadMemoryLength); + + void setRasterSampleCount(NS::UInteger rasterSampleCount); + + void setRasterizationEnabled(bool rasterizationEnabled); + + void setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup); + + void setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup); + + void setShaderValidation(MTL::ShaderValidation shaderValidation); + + void setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat); + + void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers); + + ShaderValidation shaderValidation() const; + + PixelFormat stencilAttachmentPixelFormat() const; + + bool supportIndirectCommandBuffers() const; +}; + +} +_MTL_INLINE 
MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::BlendOperation MTL::RenderPipelineColorAttachmentDescriptor::alphaBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(alphaBlendOperation)); +} + +_MTL_INLINE bool MTL::RenderPipelineColorAttachmentDescriptor::blendingEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isBlendingEnabled)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::destinationAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::destinationRGBBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(destinationRGBBlendFactor)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::RenderPipelineColorAttachmentDescriptor::isBlendingEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isBlendingEnabled)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::BlendOperation MTL::RenderPipelineColorAttachmentDescriptor::rgbBlendOperation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rgbBlendOperation)); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setAlphaBlendOperation(MTL::BlendOperation alphaBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaBlendOperation_), alphaBlendOperation); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setBlendingEnabled(bool blendingEnabled) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setBlendingEnabled_), blendingEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setDestinationAlphaBlendFactor(MTL::BlendFactor destinationAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationAlphaBlendFactor_), destinationAlphaBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setDestinationRGBBlendFactor(MTL::BlendFactor destinationRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDestinationRGBBlendFactor_), destinationRGBBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setRgbBlendOperation(MTL::BlendOperation rgbBlendOperation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRgbBlendOperation_), rgbBlendOperation); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setSourceAlphaBlendFactor(MTL::BlendFactor sourceAlphaBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceAlphaBlendFactor_), sourceAlphaBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setSourceRGBBlendFactor(MTL::BlendFactor sourceRGBBlendFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSourceRGBBlendFactor_), sourceRGBBlendFactor); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptor::setWriteMask(MTL::ColorWriteMask writeMask) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWriteMask_), writeMask); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::sourceAlphaBlendFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceAlphaBlendFactor)); +} + +_MTL_INLINE MTL::BlendFactor MTL::RenderPipelineColorAttachmentDescriptor::sourceRGBBlendFactor() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(sourceRGBBlendFactor)); +} + +_MTL_INLINE MTL::ColorWriteMask MTL::RenderPipelineColorAttachmentDescriptor::writeMask() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(writeMask)); +} + +_MTL_INLINE MTL::LogicalToPhysicalColorAttachmentMap* MTL::LogicalToPhysicalColorAttachmentMap::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLLogicalToPhysicalColorAttachmentMap)); +} + +_MTL_INLINE NS::UInteger MTL::LogicalToPhysicalColorAttachmentMap::getPhysicalIndex(NS::UInteger logicalIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(getPhysicalIndexForLogicalIndex_), logicalIndex); +} + +_MTL_INLINE MTL::LogicalToPhysicalColorAttachmentMap* MTL::LogicalToPhysicalColorAttachmentMap::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::LogicalToPhysicalColorAttachmentMap::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::LogicalToPhysicalColorAttachmentMap::setPhysicalIndex(NS::UInteger physicalIndex, NS::UInteger logicalIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPhysicalIndex_forLogicalIndex_), physicalIndex, logicalIndex); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineReflection::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineReflection)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::fragmentArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::fragmentBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBindings)); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineReflection::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::meshBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshBindings)); +} + +_MTL_INLINE NS::Array* 
MTL::RenderPipelineReflection::objectBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectBindings)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::tileArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::tileBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileBindings)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::vertexArguments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexArguments)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineReflection::vertexBindings() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBindings)); +} + +_MTL_INLINE MTL::RenderPipelineDescriptor* MTL::RenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineDescriptor)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::alphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::alphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineDescriptor::depthAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachmentPixelFormat)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::RenderPipelineDescriptor::fragmentBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBuffers)); +} + +_MTL_INLINE MTL::Function* 
MTL::RenderPipelineDescriptor::fragmentFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::RenderPipelineDescriptor::fragmentLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::fragmentPreloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentPreloadedLibraries)); +} + +_MTL_INLINE MTL::RenderPipelineDescriptor* MTL::RenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PrimitiveTopologyClass MTL::RenderPipelineDescriptor::inputPrimitiveTopology() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(inputPrimitiveTopology)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isAlphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isAlphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::isTessellationFactorScaleEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isTessellationFactorScaleEnabled)); +} + +_MTL_INLINE NS::String* MTL::RenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxFragmentCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxFragmentCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxTessellationFactor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTessellationFactor)); +} + +_MTL_INLINE 
NS::UInteger MTL::RenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::maxVertexCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setAlphaToCoverageEnabled(bool alphaToCoverageEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageEnabled_), alphaToCoverageEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setAlphaToOneEnabled(bool alphaToOneEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneEnabled_), alphaToOneEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachmentPixelFormat_), depthAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setFragmentFunction(const MTL::Function* fragmentFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunction_), fragmentFunction); +} + +_MTL_INLINE void 
MTL::RenderPipelineDescriptor::setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentLinkedFunctions_), fragmentLinkedFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setFragmentPreloadedLibraries(const NS::Array* fragmentPreloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentPreloadedLibraries_), fragmentPreloadedLibraries); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setInputPrimitiveTopology(MTL::PrimitiveTopologyClass inputPrimitiveTopology) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInputPrimitiveTopology_), inputPrimitiveTopology); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxFragmentCallStackDepth(NS::UInteger maxFragmentCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxFragmentCallStackDepth_), maxFragmentCallStackDepth); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxTessellationFactor(NS::UInteger maxTessellationFactor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTessellationFactor_), maxTessellationFactor); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setMaxVertexCallStackDepth(NS::UInteger maxVertexCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexCallStackDepth_), maxVertexCallStackDepth); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void 
MTL::RenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachmentPixelFormat_), stencilAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportAddingFragmentBinaryFunctions(bool supportAddingFragmentBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingFragmentBinaryFunctions_), supportAddingFragmentBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportAddingVertexBinaryFunctions(bool supportAddingVertexBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingVertexBinaryFunctions_), supportAddingVertexBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationControlPointIndexType(MTL::TessellationControlPointIndexType tessellationControlPointIndexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationControlPointIndexType_), tessellationControlPointIndexType); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorFormat(MTL::TessellationFactorFormat tessellationFactorFormat) 
+{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorFormat_), tessellationFactorFormat); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorScaleEnabled(bool tessellationFactorScaleEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorScaleEnabled_), tessellationFactorScaleEnabled); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationFactorStepFunction(MTL::TessellationFactorStepFunction tessellationFactorStepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationFactorStepFunction_), tessellationFactorStepFunction); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationOutputWindingOrder(MTL::Winding tessellationOutputWindingOrder) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationOutputWindingOrder_), tessellationOutputWindingOrder); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setTessellationPartitionMode(MTL::TessellationPartitionMode tessellationPartitionMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTessellationPartitionMode_), tessellationPartitionMode); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexDescriptor(const MTL::VertexDescriptor* vertexDescriptor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexDescriptor_), vertexDescriptor); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexFunction(const MTL::Function* vertexFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexFunction_), vertexFunction); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexLinkedFunctions(const MTL::LinkedFunctions* vertexLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexLinkedFunctions_), vertexLinkedFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineDescriptor::setVertexPreloadedLibraries(const NS::Array* vertexPreloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexPreloadedLibraries_), vertexPreloadedLibraries); +} + 
+_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::PixelFormat MTL::RenderPipelineDescriptor::stencilAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachmentPixelFormat)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportAddingFragmentBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingFragmentBinaryFunctions)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportAddingVertexBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingVertexBinaryFunctions)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE MTL::TessellationControlPointIndexType MTL::RenderPipelineDescriptor::tessellationControlPointIndexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationControlPointIndexType)); +} + +_MTL_INLINE MTL::TessellationFactorFormat MTL::RenderPipelineDescriptor::tessellationFactorFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationFactorFormat)); +} + +_MTL_INLINE bool MTL::RenderPipelineDescriptor::tessellationFactorScaleEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isTessellationFactorScaleEnabled)); +} + +_MTL_INLINE MTL::TessellationFactorStepFunction MTL::RenderPipelineDescriptor::tessellationFactorStepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationFactorStepFunction)); +} + +_MTL_INLINE MTL::Winding MTL::RenderPipelineDescriptor::tessellationOutputWindingOrder() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationOutputWindingOrder)); +} + +_MTL_INLINE MTL::TessellationPartitionMode 
MTL::RenderPipelineDescriptor::tessellationPartitionMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tessellationPartitionMode)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::RenderPipelineDescriptor::vertexBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexBuffers)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::RenderPipelineDescriptor::vertexDescriptor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexDescriptor)); +} + +_MTL_INLINE MTL::Function* MTL::RenderPipelineDescriptor::vertexFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::RenderPipelineDescriptor::vertexLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexLinkedFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineDescriptor::vertexPreloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexPreloadedLibraries)); +} + +_MTL_INLINE MTL::RenderPipelineFunctionsDescriptor* MTL::RenderPipelineFunctionsDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineFunctionsDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::fragmentAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL::RenderPipelineFunctionsDescriptor* MTL::RenderPipelineFunctionsDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setFragmentAdditionalBinaryFunctions(const NS::Array* fragmentAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentAdditionalBinaryFunctions_), fragmentAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setTileAdditionalBinaryFunctions(const NS::Array* tileAdditionalBinaryFunctions) +{ + 
Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileAdditionalBinaryFunctions_), tileAdditionalBinaryFunctions); +} + +_MTL_INLINE void MTL::RenderPipelineFunctionsDescriptor::setVertexAdditionalBinaryFunctions(const NS::Array* vertexAdditionalBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setVertexAdditionalBinaryFunctions_), vertexAdditionalBinaryFunctions); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::tileAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileAdditionalBinaryFunctions)); +} + +_MTL_INLINE NS::Array* MTL::RenderPipelineFunctionsDescriptor::vertexAdditionalBinaryFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(vertexAdditionalBinaryFunctions)); +} + +_MTL_INLINE MTL::Device* MTL::RenderPipelineState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const NS::String* name, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithName_stage_), name, stage); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const MTL4::BinaryFunction* function, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_stage_), function, stage); +} + +_MTL_INLINE MTL::FunctionHandle* MTL::RenderPipelineState::functionHandle(const MTL::Function* function, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_stage_), function, stage); +} + +_MTL_INLINE MTL::ResourceID MTL::RenderPipelineState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::imageblockMemoryLength(MTL::Size imageblockDimensions) +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(imageblockMemoryLengthForDimensions_), imageblockDimensions); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::imageblockSampleLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(imageblockSampleLength)); +} + +_MTL_INLINE NS::String* MTL::RenderPipelineState::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::meshThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadExecutionWidth)); +} + +_MTL_INLINE MTL::IntersectionFunctionTable* MTL::RenderPipelineState::newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_stage_), descriptor, stage); +} + +_MTL_INLINE MTL4::PipelineDescriptor* MTL::RenderPipelineState::newRenderPipelineDescriptor() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineDescriptorForSpecialization)); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::RenderPipelineState::newRenderPipelineState(const MTL4::RenderPipelineBinaryFunctionsDescriptor* 
binaryFunctionsDescriptor, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithBinaryFunctions_error_), binaryFunctionsDescriptor, error); +} + +_MTL_INLINE MTL::RenderPipelineState* MTL::RenderPipelineState::newRenderPipelineState(const MTL::RenderPipelineFunctionsDescriptor* additionalBinaryFunctions, NS::Error** error) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRenderPipelineStateWithAdditionalBinaryFunctions_error_), additionalBinaryFunctions, error); +} + +_MTL_INLINE MTL::VisibleFunctionTable* MTL::RenderPipelineState::newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor, MTL::RenderStages stage) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newVisibleFunctionTableWithDescriptor_stage_), descriptor, stage); +} + +_MTL_INLINE NS::UInteger MTL::RenderPipelineState::objectThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadExecutionWidth)); +} + +_MTL_INLINE MTL::RenderPipelineReflection* MTL::RenderPipelineState::reflection() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reflection)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::RenderPipelineState::requiredThreadsPerTileThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerTileThreadgroup)); +} + +_MTL_INLINE MTL::ShaderValidation MTL::RenderPipelineState::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE bool MTL::RenderPipelineState::supportIndirectCommandBuffers() const +{ + return 
Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} + +_MTL_INLINE bool MTL::RenderPipelineState::threadgroupSizeMatchesTileSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLRenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::RenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptor* MTL::RenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::RenderPipelineColorAttachmentDescriptorArray::setObject(const MTL::RenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineColorAttachmentDescriptor)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::PixelFormat MTL::TileRenderPipelineColorAttachmentDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void MTL::TileRenderPipelineColorAttachmentDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE 
MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineColorAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineColorAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineColorAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptor* MTL::TileRenderPipelineColorAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::TileRenderPipelineColorAttachmentDescriptorArray::setObject(const MTL::TileRenderPipelineColorAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::TileRenderPipelineDescriptor* MTL::TileRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTileRenderPipelineDescriptor)); +} + +_MTL_INLINE NS::Array* MTL::TileRenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::TileRenderPipelineColorAttachmentDescriptorArray* MTL::TileRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::TileRenderPipelineDescriptor* MTL::TileRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::TileRenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::TileRenderPipelineDescriptor::linkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(linkedFunctions)); +} + +_MTL_INLINE NS::UInteger 
MTL::TileRenderPipelineDescriptor::maxCallStackDepth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxCallStackDepth)); +} + +_MTL_INLINE NS::UInteger MTL::TileRenderPipelineDescriptor::maxTotalThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup)); +} + +_MTL_INLINE NS::Array* MTL::TileRenderPipelineDescriptor::preloadedLibraries() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(preloadedLibraries)); +} + +_MTL_INLINE NS::UInteger MTL::TileRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE MTL::Size MTL::TileRenderPipelineDescriptor::requiredThreadsPerThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup)); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLinkedFunctions_), linkedFunctions); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup); +} 
+ +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportAddingBinaryFunctions_), supportAddingBinaryFunctions); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setThreadgroupSizeMatchesTileSize(bool threadgroupSizeMatchesTileSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setThreadgroupSizeMatchesTileSize_), threadgroupSizeMatchesTileSize); +} + +_MTL_INLINE void MTL::TileRenderPipelineDescriptor::setTileFunction(const MTL::Function* tileFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTileFunction_), tileFunction); +} + +_MTL_INLINE MTL::ShaderValidation MTL::TileRenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE bool MTL::TileRenderPipelineDescriptor::supportAddingBinaryFunctions() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportAddingBinaryFunctions)); +} + +_MTL_INLINE bool MTL::TileRenderPipelineDescriptor::threadgroupSizeMatchesTileSize() 
const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(threadgroupSizeMatchesTileSize)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::TileRenderPipelineDescriptor::tileBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::TileRenderPipelineDescriptor::tileFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tileFunction)); +} + +_MTL_INLINE MTL::MeshRenderPipelineDescriptor* MTL::MeshRenderPipelineDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLMeshRenderPipelineDescriptor)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::alphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::alphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE NS::Array* MTL::MeshRenderPipelineDescriptor::binaryArchives() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(binaryArchives)); +} + +_MTL_INLINE MTL::RenderPipelineColorAttachmentDescriptorArray* MTL::MeshRenderPipelineDescriptor::colorAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(colorAttachments)); +} + +_MTL_INLINE MTL::PixelFormat MTL::MeshRenderPipelineDescriptor::depthAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depthAttachmentPixelFormat)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::fragmentBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::fragmentFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(fragmentFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::fragmentLinkedFunctions() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(fragmentLinkedFunctions)); +} + +_MTL_INLINE MTL::MeshRenderPipelineDescriptor* MTL::MeshRenderPipelineDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isAlphaToCoverageEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToCoverageEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isAlphaToOneEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAlphaToOneEnabled)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::isRasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE NS::String* MTL::MeshRenderPipelineDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadgroupsPerMeshGrid() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadgroupsPerMeshGrid)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxTotalThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::maxVertexAmplificationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxVertexAmplificationCount)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::meshBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::meshFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* 
MTL::MeshRenderPipelineDescriptor::meshLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshLinkedFunctions)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::meshThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(meshThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::MeshRenderPipelineDescriptor::objectBuffers() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectBuffers)); +} + +_MTL_INLINE MTL::Function* MTL::MeshRenderPipelineDescriptor::objectFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectFunction)); +} + +_MTL_INLINE MTL::LinkedFunctions* MTL::MeshRenderPipelineDescriptor::objectLinkedFunctions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectLinkedFunctions)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::objectThreadgroupSizeIsMultipleOfThreadExecutionWidth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectThreadgroupSizeIsMultipleOfThreadExecutionWidth)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::payloadMemoryLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(payloadMemoryLength)); +} + +_MTL_INLINE NS::UInteger MTL::MeshRenderPipelineDescriptor::rasterSampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rasterSampleCount)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::rasterizationEnabled() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isRasterizationEnabled)); +} + +_MTL_INLINE MTL::Size MTL::MeshRenderPipelineDescriptor::requiredThreadsPerMeshThreadgroup() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(requiredThreadsPerMeshThreadgroup)); +} + +_MTL_INLINE MTL::Size MTL::MeshRenderPipelineDescriptor::requiredThreadsPerObjectThreadgroup() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(requiredThreadsPerObjectThreadgroup)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setAlphaToCoverageEnabled(bool alphaToCoverageEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToCoverageEnabled_), alphaToCoverageEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setAlphaToOneEnabled(bool alphaToOneEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAlphaToOneEnabled_), alphaToOneEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setDepthAttachmentPixelFormat(MTL::PixelFormat depthAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepthAttachmentPixelFormat_), depthAttachmentPixelFormat); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setFragmentFunction(const MTL::Function* fragmentFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentFunction_), fragmentFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setFragmentLinkedFunctions(const MTL::LinkedFunctions* fragmentLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFragmentLinkedFunctions_), fragmentLinkedFunctions); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadgroupsPerMeshGrid(NS::UInteger maxTotalThreadgroupsPerMeshGrid) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadgroupsPerMeshGrid_), maxTotalThreadgroupsPerMeshGrid); +} + +_MTL_INLINE void 
MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerMeshThreadgroup(NS::UInteger maxTotalThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerMeshThreadgroup_), maxTotalThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxTotalThreadsPerObjectThreadgroup(NS::UInteger maxTotalThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerObjectThreadgroup_), maxTotalThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMaxVertexAmplificationCount(NS::UInteger maxVertexAmplificationCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxVertexAmplificationCount_), maxVertexAmplificationCount); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshFunction(const MTL::Function* meshFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshFunction_), meshFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshLinkedFunctions(const MTL::LinkedFunctions* meshLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshLinkedFunctions_), meshLinkedFunctions); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool meshThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMeshThreadgroupSizeIsMultipleOfThreadExecutionWidth_), meshThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setObjectFunction(const MTL::Function* objectFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectFunction_), objectFunction); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setObjectLinkedFunctions(const MTL::LinkedFunctions* objectLinkedFunctions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectLinkedFunctions_), objectLinkedFunctions); +} + +_MTL_INLINE void 
MTL::MeshRenderPipelineDescriptor::setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth(bool objectThreadgroupSizeIsMultipleOfThreadExecutionWidth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObjectThreadgroupSizeIsMultipleOfThreadExecutionWidth_), objectThreadgroupSizeIsMultipleOfThreadExecutionWidth); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setPayloadMemoryLength(NS::UInteger payloadMemoryLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPayloadMemoryLength_), payloadMemoryLength); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRasterSampleCount(NS::UInteger rasterSampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterSampleCount_), rasterSampleCount); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRasterizationEnabled(bool rasterizationEnabled) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRasterizationEnabled_), rasterizationEnabled); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRequiredThreadsPerMeshThreadgroup(MTL::Size requiredThreadsPerMeshThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerMeshThreadgroup_), requiredThreadsPerMeshThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setRequiredThreadsPerObjectThreadgroup(MTL::Size requiredThreadsPerObjectThreadgroup) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerObjectThreadgroup_), requiredThreadsPerObjectThreadgroup); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation); +} + +_MTL_INLINE void MTL::MeshRenderPipelineDescriptor::setStencilAttachmentPixelFormat(MTL::PixelFormat stencilAttachmentPixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStencilAttachmentPixelFormat_), stencilAttachmentPixelFormat); +} + +_MTL_INLINE void 
MTL::MeshRenderPipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers); +} + +_MTL_INLINE MTL::ShaderValidation MTL::MeshRenderPipelineDescriptor::shaderValidation() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(shaderValidation)); +} + +_MTL_INLINE MTL::PixelFormat MTL::MeshRenderPipelineDescriptor::stencilAttachmentPixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stencilAttachmentPixelFormat)); +} + +_MTL_INLINE bool MTL::MeshRenderPipelineDescriptor::supportIndirectCommandBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers)); +} diff --git a/dist/include/metal_cpp/Metal/MTLResidencySet.hpp b/dist/include/metal_cpp/Metal/MTLResidencySet.hpp new file mode 100644 index 0000000..d073972 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLResidencySet.hpp @@ -0,0 +1,178 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResidencySet.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Allocation; +class Device; +class ResidencySetDescriptor; + +class ResidencySetDescriptor : public NS::Copying +{ +public: + static ResidencySetDescriptor* alloc(); + + ResidencySetDescriptor* init(); + NS::UInteger initialCapacity() const; + + NS::String* label() const; + + void setInitialCapacity(NS::UInteger initialCapacity); + + void setLabel(const NS::String* label); +}; +class ResidencySet : public NS::Referencing +{ +public: + void addAllocation(const MTL::Allocation* allocation); + void addAllocations(const MTL::Allocation* const allocations[], NS::UInteger count); + + NS::Array* allAllocations() const; + + uint64_t allocatedSize() const; + + NS::UInteger allocationCount() const; + + void commit(); + + bool containsAllocation(const MTL::Allocation* anAllocation); + + Device* device() const; + + void endResidency(); + + NS::String* label() const; + + void removeAllAllocations(); + + void removeAllocation(const MTL::Allocation* allocation); + void removeAllocations(const MTL::Allocation* const allocations[], NS::UInteger count); + + void requestResidency(); +}; + +} +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResidencySetDescriptor)); +} + +_MTL_INLINE MTL::ResidencySetDescriptor* MTL::ResidencySetDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::ResidencySetDescriptor::initialCapacity() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initialCapacity)); +} + +_MTL_INLINE NS::String* MTL::ResidencySetDescriptor::label() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setInitialCapacity(NS::UInteger initialCapacity) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setInitialCapacity_), initialCapacity); +} + +_MTL_INLINE void MTL::ResidencySetDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ResidencySet::addAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocation_), allocation); +} + +_MTL_INLINE void MTL::ResidencySet::addAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(addAllocations_count_), allocations, count); +} + +_MTL_INLINE NS::Array* MTL::ResidencySet::allAllocations() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allAllocations)); +} + +_MTL_INLINE uint64_t MTL::ResidencySet::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE NS::UInteger MTL::ResidencySet::allocationCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocationCount)); +} + +_MTL_INLINE void MTL::ResidencySet::commit() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(commit)); +} + +_MTL_INLINE bool MTL::ResidencySet::containsAllocation(const MTL::Allocation* anAllocation) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(containsAllocation_), anAllocation); +} + +_MTL_INLINE MTL::Device* MTL::ResidencySet::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE void MTL::ResidencySet::endResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(endResidency)); +} + +_MTL_INLINE NS::String* MTL::ResidencySet::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::ResidencySet::removeAllAllocations() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllAllocations)); 
+} + +_MTL_INLINE void MTL::ResidencySet::removeAllocation(const MTL::Allocation* allocation) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocation_), allocation); +} + +_MTL_INLINE void MTL::ResidencySet::removeAllocations(const MTL::Allocation* const allocations[], NS::UInteger count) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(removeAllocations_count_), allocations, count); +} + +_MTL_INLINE void MTL::ResidencySet::requestResidency() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(requestResidency)); +} diff --git a/dist/include/metal_cpp/Metal/MTLResource.hpp b/dist/include/metal_cpp/Metal/MTLResource.hpp new file mode 100644 index 0000000..21e49bb --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLResource.hpp @@ -0,0 +1,190 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResource.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLAllocation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include + +namespace MTL +{ +class Device; +class Heap; +_MTL_ENUM(NS::UInteger, PurgeableState) { + PurgeableStateKeepCurrent = 1, + PurgeableStateNonVolatile = 2, + PurgeableStateVolatile = 3, + PurgeableStateEmpty = 4, +}; + +_MTL_ENUM(NS::UInteger, CPUCacheMode) { + CPUCacheModeDefaultCache = 0, + CPUCacheModeWriteCombined = 1, +}; + +_MTL_ENUM(NS::UInteger, StorageMode) { + StorageModeShared = 0, + StorageModeManaged = 1, + StorageModePrivate = 2, + StorageModeMemoryless = 3, +}; + +_MTL_ENUM(NS::UInteger, HazardTrackingMode) { + HazardTrackingModeDefault = 0, + HazardTrackingModeUntracked = 1, + HazardTrackingModeTracked = 2, +}; + +_MTL_ENUM(NS::Integer, SparsePageSize) { + SparsePageSize16 = 101, + SparsePageSize64 = 102, + SparsePageSize256 = 103, +}; + +_MTL_ENUM(NS::Integer, BufferSparseTier) { + BufferSparseTierNone = 0, + BufferSparseTier1 = 1, +}; + +_MTL_ENUM(NS::Integer, TextureSparseTier) { + TextureSparseTierNone = 0, + TextureSparseTier1 = 1, + TextureSparseTier2 = 2, +}; + +_MTL_OPTIONS(NS::UInteger, ResourceOptions) { + ResourceCPUCacheModeDefaultCache = 0, + ResourceCPUCacheModeWriteCombined = 1, + ResourceStorageModeShared = 0, + ResourceStorageModeManaged = 1 << 4, + ResourceStorageModePrivate = 1 << 5, + ResourceStorageModeMemoryless = 1 << 5, + ResourceHazardTrackingModeDefault = 0, + ResourceHazardTrackingModeUntracked = 1 << 8, + ResourceHazardTrackingModeTracked = 1 << 9, + ResourceOptionCPUCacheModeDefault = 0, + ResourceOptionCPUCacheModeWriteCombined = 1, +}; + +class Resource : public NS::Referencing +{ +public: + NS::UInteger allocatedSize() const; + + CPUCacheMode cpuCacheMode() const; + + 
Device* device() const; + + HazardTrackingMode hazardTrackingMode() const; + + Heap* heap() const; + NS::UInteger heapOffset() const; + + bool isAliasable(); + + NS::String* label() const; + + void makeAliasable(); + + ResourceOptions resourceOptions() const; + + void setLabel(const NS::String* label); + + kern_return_t setOwner(task_id_token_t task_id_token); + + PurgeableState setPurgeableState(MTL::PurgeableState state); + + StorageMode storageMode() const; +}; + +} +_MTL_INLINE NS::UInteger MTL::Resource::allocatedSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allocatedSize)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::Resource::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::Device* MTL::Resource::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::Resource::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::Heap* MTL::Resource::heap() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heap)); +} + +_MTL_INLINE NS::UInteger MTL::Resource::heapOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(heapOffset)); +} + +_MTL_INLINE bool MTL::Resource::isAliasable() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isAliasable)); +} + +_MTL_INLINE NS::String* MTL::Resource::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE void MTL::Resource::makeAliasable() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(makeAliasable)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::Resource::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::Resource::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE kern_return_t 
MTL::Resource::setOwner(task_id_token_t task_id_token) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setOwnerWithIdentity_), task_id_token); +} + +_MTL_INLINE MTL::PurgeableState MTL::Resource::setPurgeableState(MTL::PurgeableState state) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setPurgeableState_), state); +} + +_MTL_INLINE MTL::StorageMode MTL::Resource::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} diff --git a/dist/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp b/dist/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp new file mode 100644 index 0000000..3f565c3 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp @@ -0,0 +1,98 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStateCommandEncoder.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class Fence; +struct Region; +class Texture; +_MTL_ENUM(NS::UInteger, SparseTextureMappingMode) { + SparseTextureMappingModeMap = 0, + SparseTextureMappingModeUnmap = 1, +}; + +struct MapIndirectArguments +{ + uint32_t regionOriginX; + uint32_t regionOriginY; + uint32_t regionOriginZ; + uint32_t regionSizeWidth; + uint32_t regionSizeHeight; + uint32_t regionSizeDepth; + uint32_t mipMapLevel; + uint32_t sliceId; +} _MTL_PACKED; + +class ResourceStateCommandEncoder : public NS::Referencing +{ +public: + void moveTextureMappingsFromTexture(const MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin); + + void updateFence(const MTL::Fence* fence); + + void updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region region, const NS::UInteger mipLevel, const NS::UInteger slice); + void updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset); + void updateTextureMappings(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region* regions, const NS::UInteger* mipLevels, const NS::UInteger* slices, NS::UInteger numRegions); + + void waitForFence(const MTL::Fence* fence); +}; + +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::moveTextureMappingsFromTexture(const 
MTL::Texture* sourceTexture, NS::UInteger sourceSlice, NS::UInteger sourceLevel, MTL::Origin sourceOrigin, MTL::Size sourceSize, const MTL::Texture* destinationTexture, NS::UInteger destinationSlice, NS::UInteger destinationLevel, MTL::Origin destinationOrigin) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(moveTextureMappingsFromTexture_sourceSlice_sourceLevel_sourceOrigin_sourceSize_toTexture_destinationSlice_destinationLevel_destinationOrigin_), sourceTexture, sourceSlice, sourceLevel, sourceOrigin, sourceSize, destinationTexture, destinationSlice, destinationLevel, destinationOrigin); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateFence_), fence); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region region, const NS::UInteger mipLevel, const NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMapping_mode_region_mipLevel_slice_), texture, mode, region, mipLevel, slice); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMapping(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Buffer* indirectBuffer, NS::UInteger indirectBufferOffset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMapping_mode_indirectBuffer_indirectBufferOffset_), texture, mode, indirectBuffer, indirectBufferOffset); +} + +_MTL_INLINE void MTL::ResourceStateCommandEncoder::updateTextureMappings(const MTL::Texture* texture, const MTL::SparseTextureMappingMode mode, const MTL::Region* regions, const NS::UInteger* mipLevels, const NS::UInteger* slices, NS::UInteger numRegions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(updateTextureMappings_mode_regions_mipLevels_slices_numRegions_), texture, mode, regions, mipLevels, slices, numRegions); +} + +_MTL_INLINE void 
MTL::ResourceStateCommandEncoder::waitForFence(const MTL::Fence* fence) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(waitForFence_), fence); +} diff --git a/dist/include/metal_cpp/Metal/MTLResourceStatePass.hpp b/dist/include/metal_cpp/Metal/MTLResourceStatePass.hpp new file mode 100644 index 0000000..f368901 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLResourceStatePass.hpp @@ -0,0 +1,154 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceStatePass.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class CounterSampleBuffer; +class ResourceStatePassDescriptor; +class ResourceStatePassSampleBufferAttachmentDescriptor; +class ResourceStatePassSampleBufferAttachmentDescriptorArray; + +class ResourceStatePassSampleBufferAttachmentDescriptor : public NS::Copying +{ +public: + static ResourceStatePassSampleBufferAttachmentDescriptor* alloc(); + + NS::UInteger endOfEncoderSampleIndex() const; + + ResourceStatePassSampleBufferAttachmentDescriptor* init(); + + CounterSampleBuffer* sampleBuffer() const; + + void setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex); + + void setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer); + + void setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex); + NS::UInteger startOfEncoderSampleIndex() const; +}; +class ResourceStatePassSampleBufferAttachmentDescriptorArray : public NS::Referencing +{ +public: + static ResourceStatePassSampleBufferAttachmentDescriptorArray* alloc(); + + ResourceStatePassSampleBufferAttachmentDescriptorArray* init(); + + ResourceStatePassSampleBufferAttachmentDescriptor* object(NS::UInteger attachmentIndex); + void setObject(const MTL::ResourceStatePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex); +}; +class ResourceStatePassDescriptor : public NS::Copying +{ +public: + static ResourceStatePassDescriptor* alloc(); + + ResourceStatePassDescriptor* init(); + + static ResourceStatePassDescriptor* resourceStatePassDescriptor(); + + ResourceStatePassSampleBufferAttachmentDescriptorArray* sampleBufferAttachments() const; +}; + +} +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* 
MTL::ResourceStatePassSampleBufferAttachmentDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceStatePassSampleBufferAttachmentDescriptor::endOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(endOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* MTL::ResourceStatePassSampleBufferAttachmentDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::CounterSampleBuffer* MTL::ResourceStatePassSampleBufferAttachmentDescriptor::sampleBuffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBuffer)); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setEndOfEncoderSampleIndex(NS::UInteger endOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setEndOfEncoderSampleIndex_), endOfEncoderSampleIndex); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setSampleBuffer(const MTL::CounterSampleBuffer* sampleBuffer) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleBuffer_), sampleBuffer); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptor::setStartOfEncoderSampleIndex(NS::UInteger startOfEncoderSampleIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStartOfEncoderSampleIndex_), startOfEncoderSampleIndex); +} + +_MTL_INLINE NS::UInteger MTL::ResourceStatePassSampleBufferAttachmentDescriptor::startOfEncoderSampleIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(startOfEncoderSampleIndex)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassSampleBufferAttachmentDescriptorArray)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* 
MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptor* MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::object(NS::UInteger attachmentIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), attachmentIndex); +} + +_MTL_INLINE void MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray::setObject(const MTL::ResourceStatePassSampleBufferAttachmentDescriptor* attachment, NS::UInteger attachmentIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attachment, attachmentIndex); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceStatePassDescriptor)); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceStatePassDescriptor* MTL::ResourceStatePassDescriptor::resourceStatePassDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLResourceStatePassDescriptor), _MTL_PRIVATE_SEL(resourceStatePassDescriptor)); +} + +_MTL_INLINE MTL::ResourceStatePassSampleBufferAttachmentDescriptorArray* MTL::ResourceStatePassDescriptor::sampleBufferAttachments() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleBufferAttachments)); +} diff --git a/dist/include/metal_cpp/Metal/MTLResourceViewPool.hpp b/dist/include/metal_cpp/Metal/MTLResourceViewPool.hpp new file mode 100644 index 0000000..aa8bfda --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLResourceViewPool.hpp @@ -0,0 +1,118 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLResourceViewPool.hpp +// +// Copyright 2020-2025 Apple Inc. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class ResourceViewPool; +class ResourceViewPoolDescriptor; + +class ResourceViewPoolDescriptor : public NS::Copying +{ +public: + static ResourceViewPoolDescriptor* alloc(); + + ResourceViewPoolDescriptor* init(); + + NS::String* label() const; + + NS::UInteger resourceViewCount() const; + + void setLabel(const NS::String* label); + + void setResourceViewCount(NS::UInteger resourceViewCount); +}; +class ResourceViewPool : public NS::Referencing +{ +public: + ResourceID baseResourceID() const; + + ResourceID copyResourceViewsFromPool(const MTL::ResourceViewPool* sourcePool, NS::Range sourceRange, NS::UInteger destinationIndex); + + Device* device() const; + + NS::String* label() const; + + NS::UInteger resourceViewCount() const; +}; + +} +_MTL_INLINE MTL::ResourceViewPoolDescriptor* MTL::ResourceViewPoolDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLResourceViewPoolDescriptor)); +} + +_MTL_INLINE MTL::ResourceViewPoolDescriptor* MTL::ResourceViewPoolDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE 
NS::String* MTL::ResourceViewPoolDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceViewPoolDescriptor::resourceViewCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceViewCount)); +} + +_MTL_INLINE void MTL::ResourceViewPoolDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::ResourceViewPoolDescriptor::setResourceViewCount(NS::UInteger resourceViewCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceViewCount_), resourceViewCount); +} + +_MTL_INLINE MTL::ResourceID MTL::ResourceViewPool::baseResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(baseResourceID)); +} + +_MTL_INLINE MTL::ResourceID MTL::ResourceViewPool::copyResourceViewsFromPool(const MTL::ResourceViewPool* sourcePool, NS::Range sourceRange, NS::UInteger destinationIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(copyResourceViewsFromPool_sourceRange_destinationIndex_), sourcePool, sourceRange, destinationIndex); +} + +_MTL_INLINE MTL::Device* MTL::ResourceViewPool::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE NS::String* MTL::ResourceViewPool::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE NS::UInteger MTL::ResourceViewPool::resourceViewCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceViewCount)); +} diff --git a/dist/include/metal_cpp/Metal/MTLSampler.hpp b/dist/include/metal_cpp/Metal/MTLSampler.hpp new file mode 100644 index 0000000..f228665 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLSampler.hpp @@ -0,0 +1,345 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLSampler.hpp +// +// Copyright 
2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLDepthStencil.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Device; +class SamplerDescriptor; +_MTL_ENUM(NS::UInteger, SamplerMinMagFilter) { + SamplerMinMagFilterNearest = 0, + SamplerMinMagFilterLinear = 1, +}; + +_MTL_ENUM(NS::UInteger, SamplerMipFilter) { + SamplerMipFilterNotMipmapped = 0, + SamplerMipFilterNearest = 1, + SamplerMipFilterLinear = 2, +}; + +_MTL_ENUM(NS::UInteger, SamplerAddressMode) { + SamplerAddressModeClampToEdge = 0, + SamplerAddressModeMirrorClampToEdge = 1, + SamplerAddressModeRepeat = 2, + SamplerAddressModeMirrorRepeat = 3, + SamplerAddressModeClampToZero = 4, + SamplerAddressModeClampToBorderColor = 5, +}; + +_MTL_ENUM(NS::UInteger, SamplerBorderColor) { + SamplerBorderColorTransparentBlack = 0, + SamplerBorderColorOpaqueBlack = 1, + SamplerBorderColorOpaqueWhite = 2, +}; + +_MTL_ENUM(NS::UInteger, SamplerReductionMode) { + SamplerReductionModeWeightedAverage = 0, + SamplerReductionModeMinimum = 1, + SamplerReductionModeMaximum = 2, +}; + +class SamplerDescriptor : public NS::Copying +{ +public: + static SamplerDescriptor* alloc(); 
+ + SamplerBorderColor borderColor() const; + + CompareFunction compareFunction() const; + + SamplerDescriptor* init(); + + NS::String* label() const; + + bool lodAverage() const; + + float lodBias() const; + + float lodMaxClamp() const; + + float lodMinClamp() const; + + SamplerMinMagFilter magFilter() const; + + NS::UInteger maxAnisotropy() const; + + SamplerMinMagFilter minFilter() const; + + SamplerMipFilter mipFilter() const; + + bool normalizedCoordinates() const; + + SamplerAddressMode rAddressMode() const; + + SamplerReductionMode reductionMode() const; + + SamplerAddressMode sAddressMode() const; + + void setBorderColor(MTL::SamplerBorderColor borderColor); + + void setCompareFunction(MTL::CompareFunction compareFunction); + + void setLabel(const NS::String* label); + + void setLodAverage(bool lodAverage); + + void setLodBias(float lodBias); + + void setLodMaxClamp(float lodMaxClamp); + + void setLodMinClamp(float lodMinClamp); + + void setMagFilter(MTL::SamplerMinMagFilter magFilter); + + void setMaxAnisotropy(NS::UInteger maxAnisotropy); + + void setMinFilter(MTL::SamplerMinMagFilter minFilter); + + void setMipFilter(MTL::SamplerMipFilter mipFilter); + + void setNormalizedCoordinates(bool normalizedCoordinates); + + void setRAddressMode(MTL::SamplerAddressMode rAddressMode); + + void setReductionMode(MTL::SamplerReductionMode reductionMode); + + void setSAddressMode(MTL::SamplerAddressMode sAddressMode); + + void setSupportArgumentBuffers(bool supportArgumentBuffers); + + void setTAddressMode(MTL::SamplerAddressMode tAddressMode); + + bool supportArgumentBuffers() const; + + SamplerAddressMode tAddressMode() const; +}; +class SamplerState : public NS::Referencing +{ +public: + Device* device() const; + + ResourceID gpuResourceID() const; + + NS::String* label() const; +}; + +} +_MTL_INLINE MTL::SamplerDescriptor* MTL::SamplerDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSamplerDescriptor)); +} + +_MTL_INLINE 
MTL::SamplerBorderColor MTL::SamplerDescriptor::borderColor() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(borderColor)); +} + +_MTL_INLINE MTL::CompareFunction MTL::SamplerDescriptor::compareFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compareFunction)); +} + +_MTL_INLINE MTL::SamplerDescriptor* MTL::SamplerDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SamplerDescriptor::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::lodAverage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodAverage)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodBias() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodBias)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodMaxClamp() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodMaxClamp)); +} + +_MTL_INLINE float MTL::SamplerDescriptor::lodMinClamp() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(lodMinClamp)); +} + +_MTL_INLINE MTL::SamplerMinMagFilter MTL::SamplerDescriptor::magFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(magFilter)); +} + +_MTL_INLINE NS::UInteger MTL::SamplerDescriptor::maxAnisotropy() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(maxAnisotropy)); +} + +_MTL_INLINE MTL::SamplerMinMagFilter MTL::SamplerDescriptor::minFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(minFilter)); +} + +_MTL_INLINE MTL::SamplerMipFilter MTL::SamplerDescriptor::mipFilter() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipFilter)); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::normalizedCoordinates() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(normalizedCoordinates)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::rAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rAddressMode)); +} 
+ +_MTL_INLINE MTL::SamplerReductionMode MTL::SamplerDescriptor::reductionMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(reductionMode)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::sAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sAddressMode)); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setBorderColor(MTL::SamplerBorderColor borderColor) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBorderColor_), borderColor); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setCompareFunction(MTL::CompareFunction compareFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompareFunction_), compareFunction); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLabel(const NS::String* label) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLabel_), label); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodAverage(bool lodAverage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodAverage_), lodAverage); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodBias(float lodBias) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodBias_), lodBias); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodMaxClamp(float lodMaxClamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodMaxClamp_), lodMaxClamp); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setLodMinClamp(float lodMinClamp) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLodMinClamp_), lodMinClamp); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMagFilter(MTL::SamplerMinMagFilter magFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMagFilter_), magFilter); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMaxAnisotropy(NS::UInteger maxAnisotropy) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMaxAnisotropy_), maxAnisotropy); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setMinFilter(MTL::SamplerMinMagFilter minFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMinFilter_), minFilter); +} + +_MTL_INLINE 
void MTL::SamplerDescriptor::setMipFilter(MTL::SamplerMipFilter mipFilter) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMipFilter_), mipFilter); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setNormalizedCoordinates(bool normalizedCoordinates) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setNormalizedCoordinates_), normalizedCoordinates); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setRAddressMode(MTL::SamplerAddressMode rAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setRAddressMode_), rAddressMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setReductionMode(MTL::SamplerReductionMode reductionMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setReductionMode_), reductionMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setSAddressMode(MTL::SamplerAddressMode sAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSAddressMode_), sAddressMode); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setSupportArgumentBuffers(bool supportArgumentBuffers) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSupportArgumentBuffers_), supportArgumentBuffers); +} + +_MTL_INLINE void MTL::SamplerDescriptor::setTAddressMode(MTL::SamplerAddressMode tAddressMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTAddressMode_), tAddressMode); +} + +_MTL_INLINE bool MTL::SamplerDescriptor::supportArgumentBuffers() const +{ + return Object::sendMessageSafe(this, _MTL_PRIVATE_SEL(supportArgumentBuffers)); +} + +_MTL_INLINE MTL::SamplerAddressMode MTL::SamplerDescriptor::tAddressMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tAddressMode)); +} + +_MTL_INLINE MTL::Device* MTL::SamplerState::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::ResourceID MTL::SamplerState::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::String* MTL::SamplerState::label() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(label)); +} diff --git a/dist/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp b/dist/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp new file mode 100644 index 0000000..b9a7a48 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp @@ -0,0 +1,356 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLStageInputOutputDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLArgument.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class AttributeDescriptor; +class AttributeDescriptorArray; +class BufferLayoutDescriptor; +class BufferLayoutDescriptorArray; +class StageInputOutputDescriptor; +_MTL_ENUM(NS::UInteger, AttributeFormat) { + AttributeFormatInvalid = 0, + AttributeFormatUChar2 = 1, + AttributeFormatUChar3 = 2, + AttributeFormatUChar4 = 3, + AttributeFormatChar2 = 4, + AttributeFormatChar3 = 5, + AttributeFormatChar4 = 6, + AttributeFormatUChar2Normalized = 7, + AttributeFormatUChar3Normalized = 8, + AttributeFormatUChar4Normalized = 9, + AttributeFormatChar2Normalized = 10, + AttributeFormatChar3Normalized = 11, + AttributeFormatChar4Normalized = 12, + AttributeFormatUShort2 = 13, + AttributeFormatUShort3 = 14, + AttributeFormatUShort4 = 15, + AttributeFormatShort2 = 16, + AttributeFormatShort3 = 17, + AttributeFormatShort4 = 18, + AttributeFormatUShort2Normalized = 19, + AttributeFormatUShort3Normalized = 20, + AttributeFormatUShort4Normalized = 21, + AttributeFormatShort2Normalized = 22, + AttributeFormatShort3Normalized = 23, + AttributeFormatShort4Normalized = 24, + AttributeFormatHalf2 = 25, + AttributeFormatHalf3 = 26, + AttributeFormatHalf4 = 27, + AttributeFormatFloat = 28, + AttributeFormatFloat2 = 29, + AttributeFormatFloat3 = 30, + AttributeFormatFloat4 = 31, + AttributeFormatInt = 32, + AttributeFormatInt2 = 33, + AttributeFormatInt3 = 34, + AttributeFormatInt4 = 35, + AttributeFormatUInt = 36, + AttributeFormatUInt2 = 37, + AttributeFormatUInt3 = 38, + AttributeFormatUInt4 = 39, + AttributeFormatInt1010102Normalized = 40, + AttributeFormatUInt1010102Normalized = 41, + AttributeFormatUChar4Normalized_BGRA = 42, + 
AttributeFormatUChar = 45, + AttributeFormatChar = 46, + AttributeFormatUCharNormalized = 47, + AttributeFormatCharNormalized = 48, + AttributeFormatUShort = 49, + AttributeFormatShort = 50, + AttributeFormatUShortNormalized = 51, + AttributeFormatShortNormalized = 52, + AttributeFormatHalf = 53, + AttributeFormatFloatRG11B10 = 54, + AttributeFormatFloatRGB9E5 = 55, +}; + +_MTL_ENUM(NS::UInteger, StepFunction) { + StepFunctionConstant = 0, + StepFunctionPerVertex = 1, + StepFunctionPerInstance = 2, + StepFunctionPerPatch = 3, + StepFunctionPerPatchControlPoint = 4, + StepFunctionThreadPositionInGridX = 5, + StepFunctionThreadPositionInGridY = 6, + StepFunctionThreadPositionInGridXIndexed = 7, + StepFunctionThreadPositionInGridYIndexed = 8, +}; + +class BufferLayoutDescriptor : public NS::Copying +{ +public: + static BufferLayoutDescriptor* alloc(); + + BufferLayoutDescriptor* init(); + + void setStepFunction(MTL::StepFunction stepFunction); + + void setStepRate(NS::UInteger stepRate); + + void setStride(NS::UInteger stride); + + StepFunction stepFunction() const; + + NS::UInteger stepRate() const; + + NS::UInteger stride() const; +}; +class BufferLayoutDescriptorArray : public NS::Referencing +{ +public: + static BufferLayoutDescriptorArray* alloc(); + + BufferLayoutDescriptorArray* init(); + + BufferLayoutDescriptor* object(NS::UInteger index); + void setObject(const MTL::BufferLayoutDescriptor* bufferDesc, NS::UInteger index); +}; +class AttributeDescriptor : public NS::Copying +{ +public: + static AttributeDescriptor* alloc(); + + NS::UInteger bufferIndex() const; + + AttributeFormat format() const; + + AttributeDescriptor* init(); + + NS::UInteger offset() const; + + void setBufferIndex(NS::UInteger bufferIndex); + + void setFormat(MTL::AttributeFormat format); + + void setOffset(NS::UInteger offset); +}; +class AttributeDescriptorArray : public NS::Referencing +{ +public: + static AttributeDescriptorArray* alloc(); + + AttributeDescriptorArray* init(); + + 
AttributeDescriptor* object(NS::UInteger index); + void setObject(const MTL::AttributeDescriptor* attributeDesc, NS::UInteger index); +}; +class StageInputOutputDescriptor : public NS::Copying +{ +public: + static StageInputOutputDescriptor* alloc(); + + AttributeDescriptorArray* attributes() const; + + NS::UInteger indexBufferIndex() const; + + IndexType indexType() const; + + StageInputOutputDescriptor* init(); + + BufferLayoutDescriptorArray* layouts() const; + + void reset(); + + void setIndexBufferIndex(NS::UInteger indexBufferIndex); + + void setIndexType(MTL::IndexType indexType); + + static StageInputOutputDescriptor* stageInputOutputDescriptor(); +}; + +} +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBufferLayoutDescriptor)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStepFunction(MTL::StepFunction stepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepFunction_), stepFunction); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStepRate(NS::UInteger stepRate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepRate_), stepRate); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptor::setStride(NS::UInteger stride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStride_), stride); +} + +_MTL_INLINE MTL::StepFunction MTL::BufferLayoutDescriptor::stepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepFunction)); +} + +_MTL_INLINE NS::UInteger MTL::BufferLayoutDescriptor::stepRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepRate)); +} + +_MTL_INLINE NS::UInteger MTL::BufferLayoutDescriptor::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::BufferLayoutDescriptorArray::alloc() +{ + return 
NS::Object::alloc(_MTL_PRIVATE_CLS(MTLBufferLayoutDescriptorArray)); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::BufferLayoutDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BufferLayoutDescriptor* MTL::BufferLayoutDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::BufferLayoutDescriptorArray::setObject(const MTL::BufferLayoutDescriptor* bufferDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), bufferDesc, index); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttributeDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::AttributeDescriptor::bufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferIndex)); +} + +_MTL_INLINE MTL::AttributeFormat MTL::AttributeDescriptor::format() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(format)); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::AttributeDescriptor::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setBufferIndex(NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferIndex_), bufferIndex); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setFormat(MTL::AttributeFormat format) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFormat_), format); +} + +_MTL_INLINE void MTL::AttributeDescriptor::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::AttributeDescriptorArray* MTL::AttributeDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLAttributeDescriptorArray)); +} + +_MTL_INLINE 
MTL::AttributeDescriptorArray* MTL::AttributeDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::AttributeDescriptor* MTL::AttributeDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::AttributeDescriptorArray::setObject(const MTL::AttributeDescriptor* attributeDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attributeDesc, index); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLStageInputOutputDescriptor)); +} + +_MTL_INLINE MTL::AttributeDescriptorArray* MTL::StageInputOutputDescriptor::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE NS::UInteger MTL::StageInputOutputDescriptor::indexBufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexBufferIndex)); +} + +_MTL_INLINE MTL::IndexType MTL::StageInputOutputDescriptor::indexType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(indexType)); +} + +_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::BufferLayoutDescriptorArray* MTL::StageInputOutputDescriptor::layouts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layouts)); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::setIndexBufferIndex(NS::UInteger indexBufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexBufferIndex_), indexBufferIndex); +} + +_MTL_INLINE void MTL::StageInputOutputDescriptor::setIndexType(MTL::IndexType indexType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setIndexType_), indexType); +} + +_MTL_INLINE 
MTL::StageInputOutputDescriptor* MTL::StageInputOutputDescriptor::stageInputOutputDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLStageInputOutputDescriptor), _MTL_PRIVATE_SEL(stageInputOutputDescriptor)); +} diff --git a/dist/include/metal_cpp/Metal/MTLTensor.hpp b/dist/include/metal_cpp/Metal/MTLTensor.hpp new file mode 100644 index 0000000..221b8c9 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLTensor.hpp @@ -0,0 +1,297 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTensor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class TensorDescriptor; +class TensorExtents; + +_MTL_CONST(NS::ErrorDomain, TensorDomain); + +_MTL_ENUM(NS::Integer, TensorDataType) { + TensorDataTypeNone = 0, + TensorDataTypeFloat32 = 3, + TensorDataTypeFloat16 = 16, + TensorDataTypeBFloat16 = 121, + TensorDataTypeInt8 = 45, + TensorDataTypeUInt8 = 49, + TensorDataTypeInt16 = 37, + TensorDataTypeUInt16 = 41, + TensorDataTypeInt32 = 29, + TensorDataTypeUInt32 = 33, +}; + +_MTL_ENUM(NS::Integer, TensorError) { + TensorErrorNone = 0, + TensorErrorInternalError = 1, + TensorErrorInvalidDescriptor = 2, +}; + +_MTL_OPTIONS(NS::UInteger, TensorUsage) { + TensorUsageCompute = 1, + TensorUsageRender = 1 << 1, + TensorUsageMachineLearning = 1 << 2, +}; + +class TensorExtents : public NS::Referencing +{ +public: + static TensorExtents* alloc(); + + NS::Integer extentAtDimensionIndex(NS::UInteger dimensionIndex); + + TensorExtents* init(); + TensorExtents* init(NS::UInteger rank, const NS::Integer* values); + + NS::UInteger rank() const; +}; +class TensorDescriptor : public NS::Copying +{ +public: + static TensorDescriptor* alloc(); + + CPUCacheMode cpuCacheMode() const; + + TensorDataType dataType() const; + + TensorExtents* dimensions() const; + + HazardTrackingMode hazardTrackingMode() const; + + TensorDescriptor* init(); + + ResourceOptions resourceOptions() const; + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setDataType(MTL::TensorDataType dataType); + + void setDimensions(const MTL::TensorExtents* dimensions); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void 
setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setStorageMode(MTL::StorageMode storageMode); + + void setStrides(const MTL::TensorExtents* strides); + + void setUsage(MTL::TensorUsage usage); + + StorageMode storageMode() const; + + TensorExtents* strides() const; + + TensorUsage usage() const; +}; +class Tensor : public NS::Referencing +{ +public: + Buffer* buffer() const; + NS::UInteger bufferOffset() const; + + TensorDataType dataType() const; + + TensorExtents* dimensions() const; + + void getBytes(void* bytes, const MTL::TensorExtents* strides, const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions); + + ResourceID gpuResourceID() const; + + void replaceSliceOrigin(const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions, const void* bytes, const MTL::TensorExtents* strides); + + TensorExtents* strides() const; + + TensorUsage usage() const; +}; + +} + +_MTL_PRIVATE_DEF_CONST(NS::ErrorDomain, TensorDomain); + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorExtents)); +} + +_MTL_INLINE NS::Integer MTL::TensorExtents::extentAtDimensionIndex(NS::UInteger dimensionIndex) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(extentAtDimensionIndex_), dimensionIndex); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorExtents::init(NS::UInteger rank, const NS::Integer* values) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(initWithRank_values_), rank, values); +} + +_MTL_INLINE NS::UInteger MTL::TensorExtents::rank() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rank)); +} + +_MTL_INLINE MTL::TensorDescriptor* MTL::TensorDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTensorDescriptor)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::TensorDescriptor::cpuCacheMode() const +{ + return 
Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE MTL::TensorDataType MTL::TensorDescriptor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorDescriptor::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::TensorDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE MTL::TensorDescriptor* MTL::TensorDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::ResourceOptions MTL::TensorDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE void MTL::TensorDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setDataType(MTL::TensorDataType dataType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDataType_), dataType); +} + +_MTL_INLINE void MTL::TensorDescriptor::setDimensions(const MTL::TensorExtents* dimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDimensions_), dimensions); +} + +_MTL_INLINE void MTL::TensorDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::TensorDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE void MTL::TensorDescriptor::setStrides(const MTL::TensorExtents* strides) +{ + Object::sendMessage(this, 
_MTL_PRIVATE_SEL(setStrides_), strides); +} + +_MTL_INLINE void MTL::TensorDescriptor::setUsage(MTL::TensorUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE MTL::StorageMode MTL::TensorDescriptor::storageMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::TensorDescriptor::strides() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(strides)); +} + +_MTL_INLINE MTL::TensorUsage MTL::TensorDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE MTL::Buffer* MTL::Tensor::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + +_MTL_INLINE NS::UInteger MTL::Tensor::bufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferOffset)); +} + +_MTL_INLINE MTL::TensorDataType MTL::Tensor::dataType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dataType)); +} + +_MTL_INLINE MTL::TensorExtents* MTL::Tensor::dimensions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(dimensions)); +} + +_MTL_INLINE void MTL::Tensor::getBytes(void* bytes, const MTL::TensorExtents* strides, const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_strides_fromSliceOrigin_sliceDimensions_), bytes, strides, sliceOrigin, sliceDimensions); +} + +_MTL_INLINE MTL::ResourceID MTL::Tensor::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::Tensor::replaceSliceOrigin(const MTL::TensorExtents* sliceOrigin, const MTL::TensorExtents* sliceDimensions, const void* bytes, const MTL::TensorExtents* strides) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceSliceOrigin_sliceDimensions_withBytes_strides_), sliceOrigin, sliceDimensions, bytes, strides); +} + +_MTL_INLINE MTL::TensorExtents* 
MTL::Tensor::strides() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(strides)); +} + +_MTL_INLINE MTL::TensorUsage MTL::Tensor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} diff --git a/dist/include/metal_cpp/Metal/MTLTexture.hpp b/dist/include/metal_cpp/Metal/MTLTexture.hpp new file mode 100644 index 0000000..631d920 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLTexture.hpp @@ -0,0 +1,803 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTexture.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPixelFormat.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" +#include + +namespace MTL +{ +class Buffer; +class Device; +class Resource; +class SharedTextureHandle; +class Texture; +class TextureDescriptor; +class TextureViewDescriptor; +} + +namespace MTL +{ +_MTL_ENUM(NS::UInteger, TextureType) { + TextureType1D = 0, + TextureType1DArray = 1, + TextureType2D = 2, + TextureType2DArray = 3, + TextureType2DMultisample = 4, + TextureTypeCube = 5, + TextureTypeCubeArray = 6, + TextureType3D = 7, + TextureType2DMultisampleArray = 8, + TextureTypeTextureBuffer = 9, +}; + +_MTL_ENUM(uint8_t, TextureSwizzle) { + TextureSwizzleZero = 0, + TextureSwizzleOne = 1, + TextureSwizzleRed = 2, + TextureSwizzleGreen = 3, + TextureSwizzleBlue = 4, + TextureSwizzleAlpha = 5, +}; + +_MTL_ENUM(NS::Integer, TextureCompressionType) { + TextureCompressionTypeLossless = 0, + TextureCompressionTypeLossy = 1, +}; + +_MTL_OPTIONS(NS::UInteger, TextureUsage) { + TextureUsageUnknown = 0, + TextureUsageShaderRead = 1, + TextureUsageShaderWrite = 1 << 1, + TextureUsageRenderTarget = 1 << 2, + TextureUsagePixelFormatView = 1 << 4, + TextureUsageShaderAtomic = 1 << 5, +}; + +struct TextureSwizzleChannels +{ + + TextureSwizzleChannels(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a); + + TextureSwizzleChannels(); + + static TextureSwizzleChannels Default(); + + static TextureSwizzleChannels Make(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a); + + MTL::TextureSwizzle red; + MTL::TextureSwizzle green; + MTL::TextureSwizzle blue; + MTL::TextureSwizzle alpha; +} _MTL_PACKED; + +class 
SharedTextureHandle : public NS::SecureCoding +{ +public: + static SharedTextureHandle* alloc(); + + Device* device() const; + + SharedTextureHandle* init(); + + NS::String* label() const; +}; +class TextureDescriptor : public NS::Copying +{ +public: + static TextureDescriptor* alloc(); + + bool allowGPUOptimizedContents() const; + + NS::UInteger arrayLength() const; + + TextureCompressionType compressionType() const; + + CPUCacheMode cpuCacheMode() const; + + NS::UInteger depth() const; + + HazardTrackingMode hazardTrackingMode() const; + + NS::UInteger height() const; + + TextureDescriptor* init(); + + NS::UInteger mipmapLevelCount() const; + + PixelFormat pixelFormat() const; + + SparsePageSize placementSparsePageSize() const; + + ResourceOptions resourceOptions() const; + + NS::UInteger sampleCount() const; + + void setAllowGPUOptimizedContents(bool allowGPUOptimizedContents); + + void setArrayLength(NS::UInteger arrayLength); + + void setCompressionType(MTL::TextureCompressionType compressionType); + + void setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode); + + void setDepth(NS::UInteger depth); + + void setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode); + + void setHeight(NS::UInteger height); + + void setMipmapLevelCount(NS::UInteger mipmapLevelCount); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setPlacementSparsePageSize(MTL::SparsePageSize placementSparsePageSize); + + void setResourceOptions(MTL::ResourceOptions resourceOptions); + + void setSampleCount(NS::UInteger sampleCount); + + void setStorageMode(MTL::StorageMode storageMode); + + void setSwizzle(MTL::TextureSwizzleChannels swizzle); + + void setTextureType(MTL::TextureType textureType); + + void setUsage(MTL::TextureUsage usage); + + void setWidth(NS::UInteger width); + + StorageMode storageMode() const; + + TextureSwizzleChannels swizzle() const; + + static TextureDescriptor* texture2DDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, NS::UInteger 
height, bool mipmapped); + + static TextureDescriptor* textureBufferDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, MTL::ResourceOptions resourceOptions, MTL::TextureUsage usage); + + static TextureDescriptor* textureCubeDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger size, bool mipmapped); + + TextureType textureType() const; + + TextureUsage usage() const; + + NS::UInteger width() const; +}; +class TextureViewDescriptor : public NS::Copying +{ +public: + static TextureViewDescriptor* alloc(); + + TextureViewDescriptor* init(); + + NS::Range levelRange() const; + + PixelFormat pixelFormat() const; + + void setLevelRange(NS::Range levelRange); + + void setPixelFormat(MTL::PixelFormat pixelFormat); + + void setSliceRange(NS::Range sliceRange); + + void setSwizzle(MTL::TextureSwizzleChannels swizzle); + + void setTextureType(MTL::TextureType textureType); + + NS::Range sliceRange() const; + + TextureSwizzleChannels swizzle() const; + + TextureType textureType() const; +}; +class Texture : public NS::Referencing +{ +public: + bool allowGPUOptimizedContents() const; + + NS::UInteger arrayLength() const; + + Buffer* buffer() const; + NS::UInteger bufferBytesPerRow() const; + + NS::UInteger bufferOffset() const; + + TextureCompressionType compressionType() const; + + NS::UInteger depth() const; + + NS::UInteger firstMipmapInTail() const; + + [[deprecated("please use isFramebufferOnly instead")]] + bool framebufferOnly() const; + + void getBytes(void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage, MTL::Region region, NS::UInteger level, NS::UInteger slice); + void getBytes(void* pixelBytes, NS::UInteger bytesPerRow, MTL::Region region, NS::UInteger level); + + ResourceID gpuResourceID() const; + + NS::UInteger height() const; + + IOSurfaceRef iosurface() const; + NS::UInteger iosurfacePlane() const; + + bool isFramebufferOnly() const; + + bool isShareable() const; + + bool isSparse() const; + + NS::UInteger mipmapLevelCount() const; + 
+ Texture* newRemoteTextureViewForDevice(const MTL::Device* device); + + SharedTextureHandle* newSharedTextureHandle(); + + Texture* newTextureView(MTL::PixelFormat pixelFormat); + Texture* newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange); + Texture* newTextureView(const MTL::TextureViewDescriptor* descriptor); + Texture* newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange, MTL::TextureSwizzleChannels swizzle); + + NS::UInteger parentRelativeLevel() const; + + NS::UInteger parentRelativeSlice() const; + + Texture* parentTexture() const; + + PixelFormat pixelFormat() const; + + Texture* remoteStorageTexture() const; + + void replaceRegion(MTL::Region region, NS::UInteger level, NS::UInteger slice, const void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage); + void replaceRegion(MTL::Region region, NS::UInteger level, const void* pixelBytes, NS::UInteger bytesPerRow); + + Resource* rootResource() const; + + NS::UInteger sampleCount() const; + + [[deprecated("please use isShareable instead")]] + bool shareable() const; + + TextureSparseTier sparseTextureTier() const; + + TextureSwizzleChannels swizzle() const; + + NS::UInteger tailSizeInBytes() const; + + TextureType textureType() const; + + TextureUsage usage() const; + + NS::UInteger width() const; +}; + +} +_MTL_INLINE MTL::TextureSwizzleChannels::TextureSwizzleChannels(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a) + : red(r) + , green(g) + , blue(b) + , alpha(a) +{ +} + +_MTL_INLINE MTL::TextureSwizzleChannels::TextureSwizzleChannels() + : red(MTL::TextureSwizzleRed) + , green(MTL::TextureSwizzleGreen) + , blue(MTL::TextureSwizzleBlue) + , alpha(MTL::TextureSwizzleAlpha) +{ +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureSwizzleChannels::Default() +{ + return MTL::TextureSwizzleChannels(); +} + +_MTL_INLINE 
MTL::TextureSwizzleChannels MTL::TextureSwizzleChannels::Make(MTL::TextureSwizzle r, MTL::TextureSwizzle g, MTL::TextureSwizzle b, MTL::TextureSwizzle a) +{ + return TextureSwizzleChannels(r, g, b, a); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::SharedTextureHandle::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLSharedTextureHandle)); +} + +_MTL_INLINE MTL::Device* MTL::SharedTextureHandle::device() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(device)); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::SharedTextureHandle::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::String* MTL::SharedTextureHandle::label() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(label)); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureDescriptor)); +} + +_MTL_INLINE bool MTL::TextureDescriptor::allowGPUOptimizedContents() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowGPUOptimizedContents)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::TextureCompressionType MTL::TextureDescriptor::compressionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compressionType)); +} + +_MTL_INLINE MTL::CPUCacheMode MTL::TextureDescriptor::cpuCacheMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(cpuCacheMode)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::depth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depth)); +} + +_MTL_INLINE MTL::HazardTrackingMode MTL::TextureDescriptor::hazardTrackingMode() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(hazardTrackingMode)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::height() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(height)); +} + +_MTL_INLINE MTL::TextureDescriptor* 
MTL::TextureDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::mipmapLevelCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipmapLevelCount)); +} + +_MTL_INLINE MTL::PixelFormat MTL::TextureDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::SparsePageSize MTL::TextureDescriptor::placementSparsePageSize() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(placementSparsePageSize)); +} + +_MTL_INLINE MTL::ResourceOptions MTL::TextureDescriptor::resourceOptions() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(resourceOptions)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE void MTL::TextureDescriptor::setAllowGPUOptimizedContents(bool allowGPUOptimizedContents) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setAllowGPUOptimizedContents_), allowGPUOptimizedContents); +} + +_MTL_INLINE void MTL::TextureDescriptor::setArrayLength(NS::UInteger arrayLength) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setArrayLength_), arrayLength); +} + +_MTL_INLINE void MTL::TextureDescriptor::setCompressionType(MTL::TextureCompressionType compressionType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCompressionType_), compressionType); +} + +_MTL_INLINE void MTL::TextureDescriptor::setCpuCacheMode(MTL::CPUCacheMode cpuCacheMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setCpuCacheMode_), cpuCacheMode); +} + +_MTL_INLINE void MTL::TextureDescriptor::setDepth(NS::UInteger depth) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setDepth_), depth); +} + +_MTL_INLINE void MTL::TextureDescriptor::setHazardTrackingMode(MTL::HazardTrackingMode hazardTrackingMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHazardTrackingMode_), hazardTrackingMode); +} + +_MTL_INLINE void 
MTL::TextureDescriptor::setHeight(NS::UInteger height) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setHeight_), height); +} + +_MTL_INLINE void MTL::TextureDescriptor::setMipmapLevelCount(NS::UInteger mipmapLevelCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setMipmapLevelCount_), mipmapLevelCount); +} + +_MTL_INLINE void MTL::TextureDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::TextureDescriptor::setPlacementSparsePageSize(MTL::SparsePageSize placementSparsePageSize) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPlacementSparsePageSize_), placementSparsePageSize); +} + +_MTL_INLINE void MTL::TextureDescriptor::setResourceOptions(MTL::ResourceOptions resourceOptions) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setResourceOptions_), resourceOptions); +} + +_MTL_INLINE void MTL::TextureDescriptor::setSampleCount(NS::UInteger sampleCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSampleCount_), sampleCount); +} + +_MTL_INLINE void MTL::TextureDescriptor::setStorageMode(MTL::StorageMode storageMode) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStorageMode_), storageMode); +} + +_MTL_INLINE void MTL::TextureDescriptor::setSwizzle(MTL::TextureSwizzleChannels swizzle) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSwizzle_), swizzle); +} + +_MTL_INLINE void MTL::TextureDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE void MTL::TextureDescriptor::setUsage(MTL::TextureUsage usage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setUsage_), usage); +} + +_MTL_INLINE void MTL::TextureDescriptor::setWidth(NS::UInteger width) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setWidth_), width); +} + +_MTL_INLINE MTL::StorageMode MTL::TextureDescriptor::storageMode() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(storageMode)); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureDescriptor::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::texture2DDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, NS::UInteger height, bool mipmapped) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(texture2DDescriptorWithPixelFormat_width_height_mipmapped_), pixelFormat, width, height, mipmapped); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::textureBufferDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger width, MTL::ResourceOptions resourceOptions, MTL::TextureUsage usage) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(textureBufferDescriptorWithPixelFormat_width_resourceOptions_usage_), pixelFormat, width, resourceOptions, usage); +} + +_MTL_INLINE MTL::TextureDescriptor* MTL::TextureDescriptor::textureCubeDescriptor(MTL::PixelFormat pixelFormat, NS::UInteger size, bool mipmapped) +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLTextureDescriptor), _MTL_PRIVATE_SEL(textureCubeDescriptorWithPixelFormat_size_mipmapped_), pixelFormat, size, mipmapped); +} + +_MTL_INLINE MTL::TextureType MTL::TextureDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::TextureUsage MTL::TextureDescriptor::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE NS::UInteger MTL::TextureDescriptor::width() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(width)); +} + +_MTL_INLINE MTL::TextureViewDescriptor* MTL::TextureViewDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLTextureViewDescriptor)); +} + +_MTL_INLINE MTL::TextureViewDescriptor* MTL::TextureViewDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::Range 
MTL::TextureViewDescriptor::levelRange() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(levelRange)); +} + +_MTL_INLINE MTL::PixelFormat MTL::TextureViewDescriptor::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setLevelRange(NS::Range levelRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setLevelRange_), levelRange); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setPixelFormat_), pixelFormat); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setSliceRange(NS::Range sliceRange) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSliceRange_), sliceRange); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setSwizzle(MTL::TextureSwizzleChannels swizzle) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setSwizzle_), swizzle); +} + +_MTL_INLINE void MTL::TextureViewDescriptor::setTextureType(MTL::TextureType textureType) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureType_), textureType); +} + +_MTL_INLINE NS::Range MTL::TextureViewDescriptor::sliceRange() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sliceRange)); +} + +_MTL_INLINE MTL::TextureSwizzleChannels MTL::TextureViewDescriptor::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE MTL::TextureType MTL::TextureViewDescriptor::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE bool MTL::Texture::allowGPUOptimizedContents() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(allowGPUOptimizedContents)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::arrayLength() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(arrayLength)); +} + +_MTL_INLINE MTL::Buffer* MTL::Texture::buffer() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(buffer)); +} + 
+_MTL_INLINE NS::UInteger MTL::Texture::bufferBytesPerRow() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferBytesPerRow)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::bufferOffset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferOffset)); +} + +_MTL_INLINE MTL::TextureCompressionType MTL::Texture::compressionType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(compressionType)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::depth() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(depth)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::firstMipmapInTail() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(firstMipmapInTail)); +} + +_MTL_INLINE bool MTL::Texture::framebufferOnly() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isFramebufferOnly)); +} + +_MTL_INLINE void MTL::Texture::getBytes(void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage, MTL::Region region, NS::UInteger level, NS::UInteger slice) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_bytesPerRow_bytesPerImage_fromRegion_mipmapLevel_slice_), pixelBytes, bytesPerRow, bytesPerImage, region, level, slice); +} + +_MTL_INLINE void MTL::Texture::getBytes(void* pixelBytes, NS::UInteger bytesPerRow, MTL::Region region, NS::UInteger level) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(getBytes_bytesPerRow_fromRegion_mipmapLevel_), pixelBytes, bytesPerRow, region, level); +} + +_MTL_INLINE MTL::ResourceID MTL::Texture::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::height() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(height)); +} + +_MTL_INLINE IOSurfaceRef MTL::Texture::iosurface() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(iosurface)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::iosurfacePlane() const +{ + return Object::sendMessage(this, 
_MTL_PRIVATE_SEL(iosurfacePlane)); +} + +_MTL_INLINE bool MTL::Texture::isFramebufferOnly() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isFramebufferOnly)); +} + +_MTL_INLINE bool MTL::Texture::isShareable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isShareable)); +} + +_MTL_INLINE bool MTL::Texture::isSparse() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isSparse)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::mipmapLevelCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(mipmapLevelCount)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newRemoteTextureViewForDevice(const MTL::Device* device) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newRemoteTextureViewForDevice_), device); +} + +_MTL_INLINE MTL::SharedTextureHandle* MTL::Texture::newSharedTextureHandle() +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newSharedTextureHandle)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_), pixelFormat); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_), pixelFormat, textureType, levelRange, sliceRange); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(const MTL::TextureViewDescriptor* descriptor) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithDescriptor_), descriptor); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::newTextureView(MTL::PixelFormat pixelFormat, MTL::TextureType textureType, NS::Range levelRange, NS::Range sliceRange, MTL::TextureSwizzleChannels swizzle) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(newTextureViewWithPixelFormat_textureType_levels_slices_swizzle_), pixelFormat, 
textureType, levelRange, sliceRange, swizzle); +} + +_MTL_INLINE NS::UInteger MTL::Texture::parentRelativeLevel() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentRelativeLevel)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::parentRelativeSlice() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentRelativeSlice)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::parentTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(parentTexture)); +} + +_MTL_INLINE MTL::PixelFormat MTL::Texture::pixelFormat() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(pixelFormat)); +} + +_MTL_INLINE MTL::Texture* MTL::Texture::remoteStorageTexture() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(remoteStorageTexture)); +} + +_MTL_INLINE void MTL::Texture::replaceRegion(MTL::Region region, NS::UInteger level, NS::UInteger slice, const void* pixelBytes, NS::UInteger bytesPerRow, NS::UInteger bytesPerImage) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceRegion_mipmapLevel_slice_withBytes_bytesPerRow_bytesPerImage_), region, level, slice, pixelBytes, bytesPerRow, bytesPerImage); +} + +_MTL_INLINE void MTL::Texture::replaceRegion(MTL::Region region, NS::UInteger level, const void* pixelBytes, NS::UInteger bytesPerRow) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(replaceRegion_mipmapLevel_withBytes_bytesPerRow_), region, level, pixelBytes, bytesPerRow); +} + +_MTL_INLINE MTL::Resource* MTL::Texture::rootResource() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(rootResource)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::sampleCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sampleCount)); +} + +_MTL_INLINE bool MTL::Texture::shareable() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(isShareable)); +} + +_MTL_INLINE MTL::TextureSparseTier MTL::Texture::sparseTextureTier() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(sparseTextureTier)); +} + 
+_MTL_INLINE MTL::TextureSwizzleChannels MTL::Texture::swizzle() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(swizzle)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::tailSizeInBytes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(tailSizeInBytes)); +} + +_MTL_INLINE MTL::TextureType MTL::Texture::textureType() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(textureType)); +} + +_MTL_INLINE MTL::TextureUsage MTL::Texture::usage() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(usage)); +} + +_MTL_INLINE NS::UInteger MTL::Texture::width() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(width)); +} diff --git a/dist/include/metal_cpp/Metal/MTLTextureViewPool.hpp b/dist/include/metal_cpp/Metal/MTLTextureViewPool.hpp new file mode 100644 index 0000000..cb7556f --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLTextureViewPool.hpp @@ -0,0 +1,59 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTextureViewPool.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResourceViewPool.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class Buffer; +class Texture; +class TextureDescriptor; +class TextureViewDescriptor; + +class TextureViewPool : public NS::Referencing +{ +public: + ResourceID setTextureView(const MTL::Texture* texture, NS::UInteger index); + ResourceID setTextureView(const MTL::Texture* texture, const MTL::TextureViewDescriptor* descriptor, NS::UInteger index); + ResourceID setTextureViewFromBuffer(const MTL::Buffer* buffer, const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow, NS::UInteger index); +}; + +} +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureView(const MTL::Texture* texture, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureView_atIndex_), texture, index); +} + +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureView(const MTL::Texture* texture, const MTL::TextureViewDescriptor* descriptor, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureView_descriptor_atIndex_), texture, descriptor, index); +} + +_MTL_INLINE MTL::ResourceID MTL::TextureViewPool::setTextureViewFromBuffer(const MTL::Buffer* buffer, const MTL::TextureDescriptor* descriptor, NS::UInteger offset, NS::UInteger bytesPerRow, NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(setTextureViewFromBuffer_descriptor_offset_bytesPerRow_atIndex_), buffer, descriptor, offset, bytesPerRow, index); +} diff --git a/dist/include/metal_cpp/Metal/MTLTypes.hpp b/dist/include/metal_cpp/Metal/MTLTypes.hpp new file mode 100644 index 0000000..c6bbc03 --- /dev/null +++ 
b/dist/include/metal_cpp/Metal/MTLTypes.hpp @@ -0,0 +1,164 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLTypes.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +struct SamplePosition; + +using Coordinate2D = MTL::SamplePosition; + +struct Origin +{ + Origin() = default; + + Origin(NS::UInteger x, NS::UInteger y, NS::UInteger z); + + static Origin Make(NS::UInteger x, NS::UInteger y, NS::UInteger z); + + NS::UInteger x; + NS::UInteger y; + NS::UInteger z; +} _MTL_PACKED; + +struct Size +{ + Size() = default; + + Size(NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + static Size Make(NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + NS::UInteger width; + NS::UInteger height; + NS::UInteger depth; +} _MTL_PACKED; + +struct Region +{ + Region() = default; + + Region(NS::UInteger x, NS::UInteger width); + + Region(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height); + + Region(NS::UInteger x, NS::UInteger y, 
NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + static Region Make1D(NS::UInteger x, NS::UInteger width); + + static Region Make2D(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height); + + static Region Make3D(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth); + + MTL::Origin origin; + MTL::Size size; +} _MTL_PACKED; + +struct SamplePosition +{ + SamplePosition() = default; + + SamplePosition(float x, float y); + + static SamplePosition Make(float x, float y); + + float x; + float y; +} _MTL_PACKED; + +struct ResourceID +{ + uint64_t _impl; +} _MTL_PACKED; + +} +_MTL_INLINE MTL::Origin::Origin(NS::UInteger x, NS::UInteger y, NS::UInteger z) + : x(x) + , y(y) + , z(z) +{ +} + +_MTL_INLINE MTL::Origin MTL::Origin::Make(NS::UInteger x, NS::UInteger y, NS::UInteger z) +{ + return Origin(x, y, z); +} + +_MTL_INLINE MTL::Size::Size(NS::UInteger width, NS::UInteger height, NS::UInteger depth) + : width(width) + , height(height) + , depth(depth) +{ +} + +_MTL_INLINE MTL::Size MTL::Size::Make(NS::UInteger width, NS::UInteger height, NS::UInteger depth) +{ + return Size(width, height, depth); +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger width) + : origin(x, 0, 0) + , size(width, 1, 1) +{ +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height) + : origin(x, y, 0) + , size(width, height, 1) +{ +} + +_MTL_INLINE MTL::Region::Region(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth) + : origin(x, y, z) + , size(width, height, depth) +{ +} + +_MTL_INLINE MTL::Region MTL::Region::Make1D(NS::UInteger x, NS::UInteger width) +{ + return Region(x, width); +} + +_MTL_INLINE MTL::Region MTL::Region::Make2D(NS::UInteger x, NS::UInteger y, NS::UInteger width, NS::UInteger height) +{ + return Region(x, y, width, height); +} + +_MTL_INLINE 
MTL::Region MTL::Region::Make3D(NS::UInteger x, NS::UInteger y, NS::UInteger z, NS::UInteger width, NS::UInteger height, NS::UInteger depth) +{ + return Region(x, y, z, width, height, depth); +} + +_MTL_INLINE MTL::SamplePosition::SamplePosition(float x, float y) + : x(x) + , y(y) +{ +} + +_MTL_INLINE MTL::SamplePosition MTL::SamplePosition::Make(float x, float y) +{ + return SamplePosition(x, y); +} diff --git a/dist/include/metal_cpp/Metal/MTLVersion.hpp b/dist/include/metal_cpp/Metal/MTLVersion.hpp new file mode 100644 index 0000000..d350397 --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLVersion.hpp @@ -0,0 +1,32 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVersion.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define METALCPP_VERSION_MAJOR 370 +#define METALCPP_VERSION_MINOR 63 +#define METALCPP_VERSION_PATCH 1 + +#define METALCPP_SUPPORTS_VERSION(major, minor, patch) \ + ((major < METALCPP_VERSION_MAJOR) || \ + (major == METALCPP_VERSION_MAJOR && minor < METALCPP_VERSION_MINOR) || \ + (major == METALCPP_VERSION_MAJOR && minor == METALCPP_VERSION_MINOR && patch <= METALCPP_VERSION_PATCH)) diff --git a/dist/include/metal_cpp/Metal/MTLVertexDescriptor.hpp b/dist/include/metal_cpp/Metal/MTLVertexDescriptor.hpp new file mode 100644 index 0000000..4a38f3b --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLVertexDescriptor.hpp @@ -0,0 +1,326 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVertexDescriptor.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" + +namespace MTL +{ +class VertexAttributeDescriptor; +class VertexAttributeDescriptorArray; +class VertexBufferLayoutDescriptor; +class VertexBufferLayoutDescriptorArray; +class VertexDescriptor; +_MTL_ENUM(NS::UInteger, VertexFormat) { + VertexFormatInvalid = 0, + VertexFormatUChar2 = 1, + VertexFormatUChar3 = 2, + VertexFormatUChar4 = 3, + VertexFormatChar2 = 4, + VertexFormatChar3 = 5, + VertexFormatChar4 = 6, + VertexFormatUChar2Normalized = 7, + VertexFormatUChar3Normalized = 8, + VertexFormatUChar4Normalized = 9, + VertexFormatChar2Normalized = 10, + VertexFormatChar3Normalized = 11, + VertexFormatChar4Normalized = 12, + VertexFormatUShort2 = 13, + VertexFormatUShort3 = 14, + VertexFormatUShort4 = 15, + VertexFormatShort2 = 16, + VertexFormatShort3 = 17, + VertexFormatShort4 = 18, + VertexFormatUShort2Normalized = 19, + VertexFormatUShort3Normalized = 20, + VertexFormatUShort4Normalized = 21, + VertexFormatShort2Normalized = 22, + VertexFormatShort3Normalized = 23, + VertexFormatShort4Normalized = 24, + VertexFormatHalf2 = 25, + VertexFormatHalf3 = 26, + VertexFormatHalf4 = 27, + VertexFormatFloat = 28, + VertexFormatFloat2 = 29, + VertexFormatFloat3 = 30, + VertexFormatFloat4 = 31, + VertexFormatInt = 32, + VertexFormatInt2 = 33, + VertexFormatInt3 = 34, + VertexFormatInt4 = 35, + VertexFormatUInt = 36, + VertexFormatUInt2 = 37, + VertexFormatUInt3 = 38, + VertexFormatUInt4 = 39, + VertexFormatInt1010102Normalized = 40, + VertexFormatUInt1010102Normalized = 41, + VertexFormatUChar4Normalized_BGRA = 42, + VertexFormatUChar = 45, + VertexFormatChar = 46, + VertexFormatUCharNormalized = 47, + VertexFormatCharNormalized = 48, + VertexFormatUShort = 49, + 
VertexFormatShort = 50, + VertexFormatUShortNormalized = 51, + VertexFormatShortNormalized = 52, + VertexFormatHalf = 53, + VertexFormatFloatRG11B10 = 54, + VertexFormatFloatRGB9E5 = 55, +}; + +_MTL_ENUM(NS::UInteger, VertexStepFunction) { + VertexStepFunctionConstant = 0, + VertexStepFunctionPerVertex = 1, + VertexStepFunctionPerInstance = 2, + VertexStepFunctionPerPatch = 3, + VertexStepFunctionPerPatchControlPoint = 4, +}; + +static const NS::UInteger BufferLayoutStrideDynamic = NS::UIntegerMax; + +class VertexBufferLayoutDescriptor : public NS::Copying +{ +public: + static VertexBufferLayoutDescriptor* alloc(); + + VertexBufferLayoutDescriptor* init(); + + void setStepFunction(MTL::VertexStepFunction stepFunction); + + void setStepRate(NS::UInteger stepRate); + + void setStride(NS::UInteger stride); + + VertexStepFunction stepFunction() const; + + NS::UInteger stepRate() const; + + NS::UInteger stride() const; +}; +class VertexBufferLayoutDescriptorArray : public NS::Referencing +{ +public: + static VertexBufferLayoutDescriptorArray* alloc(); + + VertexBufferLayoutDescriptorArray* init(); + + VertexBufferLayoutDescriptor* object(NS::UInteger index); + void setObject(const MTL::VertexBufferLayoutDescriptor* bufferDesc, NS::UInteger index); +}; +class VertexAttributeDescriptor : public NS::Copying +{ +public: + static VertexAttributeDescriptor* alloc(); + + NS::UInteger bufferIndex() const; + + VertexFormat format() const; + + VertexAttributeDescriptor* init(); + + NS::UInteger offset() const; + + void setBufferIndex(NS::UInteger bufferIndex); + + void setFormat(MTL::VertexFormat format); + + void setOffset(NS::UInteger offset); +}; +class VertexAttributeDescriptorArray : public NS::Referencing +{ +public: + static VertexAttributeDescriptorArray* alloc(); + + VertexAttributeDescriptorArray* init(); + + VertexAttributeDescriptor* object(NS::UInteger index); + void setObject(const MTL::VertexAttributeDescriptor* attributeDesc, NS::UInteger index); +}; +class 
VertexDescriptor : public NS::Copying +{ +public: + static VertexDescriptor* alloc(); + + VertexAttributeDescriptorArray* attributes() const; + + VertexDescriptor* init(); + + VertexBufferLayoutDescriptorArray* layouts() const; + + void reset(); + + static VertexDescriptor* vertexDescriptor(); +}; + +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* MTL::VertexBufferLayoutDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexBufferLayoutDescriptor)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* MTL::VertexBufferLayoutDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStepFunction(MTL::VertexStepFunction stepFunction) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepFunction_), stepFunction); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStepRate(NS::UInteger stepRate) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStepRate_), stepRate); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptor::setStride(NS::UInteger stride) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setStride_), stride); +} + +_MTL_INLINE MTL::VertexStepFunction MTL::VertexBufferLayoutDescriptor::stepFunction() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepFunction)); +} + +_MTL_INLINE NS::UInteger MTL::VertexBufferLayoutDescriptor::stepRate() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stepRate)); +} + +_MTL_INLINE NS::UInteger MTL::VertexBufferLayoutDescriptor::stride() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(stride)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexBufferLayoutDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexBufferLayoutDescriptorArray)); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexBufferLayoutDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptor* 
MTL::VertexBufferLayoutDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::VertexBufferLayoutDescriptorArray::setObject(const MTL::VertexBufferLayoutDescriptor* bufferDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), bufferDesc, index); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* MTL::VertexAttributeDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttributeDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttributeDescriptor::bufferIndex() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(bufferIndex)); +} + +_MTL_INLINE MTL::VertexFormat MTL::VertexAttributeDescriptor::format() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(format)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* MTL::VertexAttributeDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE NS::UInteger MTL::VertexAttributeDescriptor::offset() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(offset)); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setBufferIndex(NS::UInteger bufferIndex) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setBufferIndex_), bufferIndex); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setFormat(MTL::VertexFormat format) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFormat_), format); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptor::setOffset(NS::UInteger offset) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setOffset_), offset); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexAttributeDescriptorArray::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexAttributeDescriptorArray)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexAttributeDescriptorArray::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexAttributeDescriptor* 
MTL::VertexAttributeDescriptorArray::object(NS::UInteger index) +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(objectAtIndexedSubscript_), index); +} + +_MTL_INLINE void MTL::VertexAttributeDescriptorArray::setObject(const MTL::VertexAttributeDescriptor* attributeDesc, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setObject_atIndexedSubscript_), attributeDesc, index); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVertexDescriptor)); +} + +_MTL_INLINE MTL::VertexAttributeDescriptorArray* MTL::VertexDescriptor::attributes() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(attributes)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE MTL::VertexBufferLayoutDescriptorArray* MTL::VertexDescriptor::layouts() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(layouts)); +} + +_MTL_INLINE void MTL::VertexDescriptor::reset() +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(reset)); +} + +_MTL_INLINE MTL::VertexDescriptor* MTL::VertexDescriptor::vertexDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLVertexDescriptor), _MTL_PRIVATE_SEL(vertexDescriptor)); +} diff --git a/dist/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp b/dist/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp new file mode 100644 index 0000000..de144ea --- /dev/null +++ b/dist/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp @@ -0,0 +1,96 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/MTLVisibleFunctionTable.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +#include "../Foundation/Foundation.hpp" +#include "MTLDefines.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLPrivate.hpp" +#include "MTLResource.hpp" +#include "MTLTypes.hpp" + +namespace MTL +{ +class FunctionHandle; +class VisibleFunctionTableDescriptor; + +class VisibleFunctionTableDescriptor : public NS::Copying +{ +public: + static VisibleFunctionTableDescriptor* alloc(); + + NS::UInteger functionCount() const; + + VisibleFunctionTableDescriptor* init(); + + void setFunctionCount(NS::UInteger functionCount); + + static VisibleFunctionTableDescriptor* visibleFunctionTableDescriptor(); +}; +class VisibleFunctionTable : public NS::Referencing +{ +public: + ResourceID gpuResourceID() const; + + void setFunction(const MTL::FunctionHandle* function, NS::UInteger index); + void setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range); +}; + +} +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::alloc() +{ + return NS::Object::alloc(_MTL_PRIVATE_CLS(MTLVisibleFunctionTableDescriptor)); +} + +_MTL_INLINE NS::UInteger MTL::VisibleFunctionTableDescriptor::functionCount() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(functionCount)); +} + +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::init() +{ + return NS::Object::init(); +} + +_MTL_INLINE void 
MTL::VisibleFunctionTableDescriptor::setFunctionCount(NS::UInteger functionCount) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctionCount_), functionCount); +} + +_MTL_INLINE MTL::VisibleFunctionTableDescriptor* MTL::VisibleFunctionTableDescriptor::visibleFunctionTableDescriptor() +{ + return Object::sendMessage(_MTL_PRIVATE_CLS(MTLVisibleFunctionTableDescriptor), _MTL_PRIVATE_SEL(visibleFunctionTableDescriptor)); +} + +_MTL_INLINE MTL::ResourceID MTL::VisibleFunctionTable::gpuResourceID() const +{ + return Object::sendMessage(this, _MTL_PRIVATE_SEL(gpuResourceID)); +} + +_MTL_INLINE void MTL::VisibleFunctionTable::setFunction(const MTL::FunctionHandle* function, NS::UInteger index) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunction_atIndex_), function, index); +} + +_MTL_INLINE void MTL::VisibleFunctionTable::setFunctions(const MTL::FunctionHandle* const functions[], NS::Range range) +{ + Object::sendMessage(this, _MTL_PRIVATE_SEL(setFunctions_withRange_), functions, range); +} diff --git a/dist/include/metal_cpp/Metal/Metal.hpp b/dist/include/metal_cpp/Metal/Metal.hpp new file mode 100644 index 0000000..0d89cc0 --- /dev/null +++ b/dist/include/metal_cpp/Metal/Metal.hpp @@ -0,0 +1,120 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// Metal/Metal.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLAccelerationStructure.hpp" +#include "MTLAccelerationStructureCommandEncoder.hpp" +#include "MTLAccelerationStructureTypes.hpp" +#include "MTLAllocation.hpp" +#include "MTLArgument.hpp" +#include "MTLArgumentEncoder.hpp" +#include "MTLBinaryArchive.hpp" +#include "MTLBlitCommandEncoder.hpp" +#include "MTLBlitPass.hpp" +#include "MTLBuffer.hpp" +#include "MTLCaptureManager.hpp" +#include "MTLCaptureScope.hpp" +#include "MTLCommandBuffer.hpp" +#include "MTLCommandEncoder.hpp" +#include "MTLCommandQueue.hpp" +#include "MTLComputeCommandEncoder.hpp" +#include "MTLComputePass.hpp" +#include "MTLComputePipeline.hpp" +#include "MTLCounters.hpp" +#include "MTLDefines.hpp" +#include "MTLDepthStencil.hpp" +#include "MTLDevice.hpp" +#include "MTLDrawable.hpp" +#include "MTLDynamicLibrary.hpp" +#include "MTLEvent.hpp" +#include "MTLFence.hpp" +#include "MTLFunctionConstantValues.hpp" +#include "MTLFunctionDescriptor.hpp" +#include "MTLFunctionHandle.hpp" +#include "MTLFunctionLog.hpp" +#include "MTLFunctionStitching.hpp" +#include "MTLHeaderBridge.hpp" +#include "MTLHeap.hpp" +#include "MTLIndirectCommandBuffer.hpp" +#include "MTLIndirectCommandEncoder.hpp" +#include "MTLIntersectionFunctionTable.hpp" +#include "MTLIOCommandBuffer.hpp" +#include "MTLIOCommandQueue.hpp" +#include "MTLIOCompressor.hpp" +#include "MTLLibrary.hpp" +#include "MTLLinkedFunctions.hpp" +#include "MTLLogState.hpp" +#include "MTLParallelRenderCommandEncoder.hpp" +#include "MTLPipeline.hpp" +#include "MTLPixelFormat.hpp" +#include 
"MTLPrivate.hpp" +#include "MTLRasterizationRate.hpp" +#include "MTLRenderCommandEncoder.hpp" +#include "MTLRenderPass.hpp" +#include "MTLRenderPipeline.hpp" +#include "MTLResidencySet.hpp" +#include "MTLResource.hpp" +#include "MTLResourceStateCommandEncoder.hpp" +#include "MTLResourceStatePass.hpp" +#include "MTLSampler.hpp" +#include "MTLStageInputOutputDescriptor.hpp" +#include "MTLTexture.hpp" +#include "MTLTypes.hpp" +#include "MTLVertexDescriptor.hpp" +#include "MTLVisibleFunctionTable.hpp" +#include "MTLVersion.hpp" +#include "MTLTensor.hpp" +#include "MTLResourceViewPool.hpp" +#include "MTLTextureViewPool.hpp" +#include "MTLDataType.hpp" +#include "MTL4ArgumentTable.hpp" +#include "MTL4BinaryFunction.hpp" +#include "MTL4CommandAllocator.hpp" +#include "MTL4CommandBuffer.hpp" +#include "MTL4CommandEncoder.hpp" +#include "MTL4CommandQueue.hpp" +#include "MTL4Counters.hpp" +#include "MTL4RenderPass.hpp" +#include "MTL4RenderCommandEncoder.hpp" +#include "MTL4ComputeCommandEncoder.hpp" +#include "MTL4MachineLearningCommandEncoder.hpp" +#include "MTL4Compiler.hpp" +#include "MTL4CompilerTask.hpp" +#include "MTL4LibraryDescriptor.hpp" +#include "MTL4FunctionDescriptor.hpp" +#include "MTL4LibraryFunctionDescriptor.hpp" +#include "MTL4SpecializedFunctionDescriptor.hpp" +#include "MTL4StitchedFunctionDescriptor.hpp" +#include "MTL4PipelineState.hpp" +#include "MTL4ComputePipeline.hpp" +#include "MTL4RenderPipeline.hpp" +#include "MTL4MachineLearningPipeline.hpp" +#include "MTL4TileRenderPipeline.hpp" +#include "MTL4MeshRenderPipeline.hpp" +#include "MTL4PipelineDataSetSerializer.hpp" +#include "MTL4Archive.hpp" +#include "MTL4CommitFeedback.hpp" +#include "MTL4BinaryFunctionDescriptor.hpp" +#include "MTL4LinkingDescriptor.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTL4FXFrameInterpolator.hpp 
b/dist/include/metal_cpp/MetalFX/MTL4FXFrameInterpolator.hpp new file mode 100644 index 0000000..1c50ec9 --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTL4FXFrameInterpolator.hpp @@ -0,0 +1,47 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXFrameInterpolator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXFrameInterpolator.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class FrameInterpolator : public NS::Referencing + { + public: + void encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::FrameInterpolator::encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/dist/include/metal_cpp/MetalFX/MTL4FXSpatialScaler.hpp b/dist/include/metal_cpp/MetalFX/MTL4FXSpatialScaler.hpp new file mode 100644 index 0000000..8ea8dfd --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTL4FXSpatialScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXSpatialScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXSpatialScaler.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class SpatialScaler : public NS::Referencing< SpatialScaler, MTLFX::SpatialScalerBase > + { + public: + void encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::SpatialScaler::encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp b/dist/include/metal_cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp new file mode 100644 index 0000000..73014bb --- 
/dev/null +++ b/dist/include/metal_cpp/MetalFX/MTL4FXTemporalDenoisedScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalDenoisedScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXTemporalDenoisedScaler.hpp" +#include "../Metal/Metal.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalDenoisedScaler : public NS::Referencing< TemporalDenoisedScaler, MTLFX::TemporalDenoisedScalerBase > + { + public: + void encodeToCommandBuffer(MTL4::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void 
MTL4FX::TemporalDenoisedScaler::encodeToCommandBuffer( MTL4::CommandBuffer* commandBuffer ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/dist/include/metal_cpp/MetalFX/MTL4FXTemporalScaler.hpp b/dist/include/metal_cpp/MetalFX/MTL4FXTemporalScaler.hpp new file mode 100644 index 0000000..3bda5dc --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTL4FXTemporalScaler.hpp @@ -0,0 +1,49 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTL4FXTemporalScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "MTLFXTemporalScaler.hpp" +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler : public NS::Referencing< TemporalScaler, MTLFX::TemporalScalerBase > + { + public: + void encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTL4FX::TemporalScaler::encodeToCommandBuffer( MTL4::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTLFXDefines.hpp b/dist/include/metal_cpp/MetalFX/MTLFXDefines.hpp new file mode 100644 index 0000000..320e0aa --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXDefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXDefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _MTLFX_EXPORT _NS_EXPORT +#define _MTLFX_EXTERN _NS_EXTERN +#define _MTLFX_INLINE _NS_INLINE +#define _MTLFX_PACKED _NS_PACKED + +#define _MTLFX_CONST( type, name ) _NS_CONST( type, name ) +#define _MTLFX_ENUM( type, name ) _NS_ENUM( type, name ) +#define _MTLFX_OPTIONS( type, name ) _NS_OPTIONS( type, name ) + +#define _MTLFX_VALIDATE_SIZE( mtlfx, name ) _NS_VALIDATE_SIZE( mtlfx, name ) +#define _MTLFX_VALIDATE_ENUM( mtlfx, name ) _NS_VALIDATE_ENUM( mtlfx, name ) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTLFXFrameInterpolator.hpp b/dist/include/metal_cpp/MetalFX/MTLFXFrameInterpolator.hpp new file mode 100644 index 0000000..10ff69c --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXFrameInterpolator.hpp @@ -0,0 +1,719 @@ 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXFrameInterpolator.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" +#include "MTLFXTemporalScaler.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler; + class TemporalDenoisedScaler; + class FrameInterpolator; +} + +namespace MTLFX +{ + class FrameInterpolatorDescriptor : public NS::Copying< FrameInterpolatorDescriptor > + { + public: + static FrameInterpolatorDescriptor* alloc(); + FrameInterpolatorDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat(MTL::PixelFormat colorTextureFormat); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat(MTL::PixelFormat outputTextureFormat); 
+ + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat(MTL::PixelFormat depthTextureFormat); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat(MTL::PixelFormat motionTextureFormat); + + MTL::PixelFormat uiTextureFormat() const; + void setUITextureFormat(MTL::PixelFormat uiTextureFormat); + + MTLFX::FrameInterpolatableScaler* scaler() const; + void setScaler(MTLFX::FrameInterpolatableScaler* scaler); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger inputWidth ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger inputHeight ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger outputWidth ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger outputHeight ); + + class FrameInterpolator* newFrameInterpolator( const MTL::Device* pDevice) const; + MTL4FX::FrameInterpolator* newFrameInterpolator( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler) const; + + static bool supportsMetal4FX(MTL::Device* device); + static bool supportsDevice(MTL::Device* device); + }; + + class FrameInterpolatorBase : public NS::Referencing + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage uiTextureUsage() const; + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + MTL::PixelFormat uiTextureFormat() const; + + MTL::Texture* colorTexture() const; + void setColorTexture(MTL::Texture* colorTexture); + + MTL::Texture* prevColorTexture() const; + void setPrevColorTexture(MTL::Texture* 
prevColorTexture); + + MTL::Texture* depthTexture() const; + void setDepthTexture(MTL::Texture* depthTexture); + + MTL::Texture* motionTexture() const; + void setMotionTexture(MTL::Texture* motionTexture); + + float motionVectorScaleX() const; + void setMotionVectorScaleX(float scaleX); + + float motionVectorScaleY() const; + void setMotionVectorScaleY(float scaleY); + + float deltaTime() const; + void setDeltaTime( float deltaTime ); + + float nearPlane() const; + void setNearPlane( float nearPlane ); + + float farPlane() const; + void setFarPlane( float farPlane ); + + float fieldOfView() const; + void setFieldOfView( float fieldOfView ); + + float aspectRatio() const; + void setAspectRatio( float aspectRatio ); + + MTL::Texture* uiTexture() const; + void setUITexture(MTL::Texture* uiTexture); + + float jitterOffsetX() const; + void setJitterOffsetX( float jitterOffsetX ); + + float jitterOffsetY() const; + void setJitterOffsetY( float jitterOffsetY ); + + bool isUITextureComposited() const; + void setIsUITextureComposited( bool uiTextureComposited ); + + bool shouldResetHistory() const; + void setShouldResetHistory( bool shouldResetHistory ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* outputTexture ); + + MTL::Fence* fence() const; + void setFence( MTL::Fence* fence ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + }; + + class FrameInterpolator : public NS::Referencing + { + public: + void encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer); + }; + +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatorDescriptor* MTLFX::FrameInterpolatorDescriptor::alloc() +{ + return NS::Object::alloc< FrameInterpolatorDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXFrameInterpolatorDescriptor ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatorDescriptor* MTLFX::FrameInterpolatorDescriptor::init() +{ + return NS::Object::init< FrameInterpolatorDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setColorTextureFormat( MTL::PixelFormat colorTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), colorTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputTextureFormat( MTL::PixelFormat outputTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), outputTextureFormat ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setDepthTextureFormat( MTL::PixelFormat depthTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), depthTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setMotionTextureFormat( MTL::PixelFormat motionTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), motionTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorDescriptor::uiTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( uiTextureFormat ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setUITextureFormat( MTL::PixelFormat uiTextureFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setUITextureFormat_ ), uiTextureFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::FrameInterpolatableScaler* MTLFX::FrameInterpolatorDescriptor::scaler() const +{ + return NS::Object::sendMessage< MTLFX::FrameInterpolatableScaler* >( this, _MTLFX_PRIVATE_SEL( scaler ) ); +} + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setScaler(MTLFX::FrameInterpolatableScaler* scaler) +{ + NS::Object::sendMessage< void >(this, _MTLFX_PRIVATE_SEL( setScaler_ ), scaler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setInputWidth( NS::UInteger inputWidth ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), inputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, 
_MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setInputHeight( NS::UInteger inputHeight ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), inputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputWidth( NS::UInteger outputWidth ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), outputWidth ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorDescriptor::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorDescriptor::setOutputHeight( NS::UInteger outputHeight ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), outputHeight ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- 
+ +_MTLFX_INLINE MTLFX::FrameInterpolator* MTLFX::FrameInterpolatorDescriptor::newFrameInterpolator( const MTL::Device* device ) const +{ + return NS::Object::sendMessage< MTLFX::FrameInterpolator* >( this, _MTLFX_PRIVATE_SEL( newFrameInterpolatorWithDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::FrameInterpolator* MTLFX::FrameInterpolatorDescriptor::newFrameInterpolator( const MTL::Device* device, const MTL4::Compiler* compiler ) const +{ + return NS::Object::sendMessage< MTL4FX::FrameInterpolator* >( this, _MTLFX_PRIVATE_SEL( newFrameInterpolatorWithDevice_compiler_ ), device, compiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorDescriptor::supportsMetal4FX(MTL::Device* device) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXFrameInterpolatorDescriptor), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorDescriptor::supportsDevice(MTL::Device* device) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXFrameInterpolatorDescriptor), _MTLFX_PRIVATE_SEL( supportsDevice_ ), device ); +} + + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::colorTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::outputTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::depthTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::motionTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::FrameInterpolatorBase::uiTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( uiTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE 
MTL::PixelFormat MTLFX::FrameInterpolatorBase::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::FrameInterpolatorBase::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::FrameInterpolatorBase::uiTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( uiTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::colorTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setColorTexture(MTL::Texture* colorTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), colorTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::prevColorTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( prevColorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void 
MTLFX::FrameInterpolatorBase::setPrevColorTexture(MTL::Texture* prevColorTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPrevColorTexture_ ), prevColorTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::depthTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDepthTexture(MTL::Texture* depthTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), depthTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::motionTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionTexture(MTL::Texture* motionTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), motionTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::motionVectorScaleX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionVectorScaleX(float scaleX) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), scaleX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::motionVectorScaleY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setMotionVectorScaleY(float scaleY) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), scaleY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::deltaTime() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( deltaTime ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDeltaTime( float deltaTime ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDeltaTime_ ), deltaTime ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::nearPlane() const +{ + return 
NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( nearPlane ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setNearPlane( float nearPlane ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNearPlane_ ), nearPlane ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::farPlane() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( farPlane ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFarPlane( float farPlane ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFarPlane_ ), farPlane ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::fieldOfView() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( fieldOfView ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFieldOfView( float fieldOfView ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFieldOfView_ ), fieldOfView ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::aspectRatio() const 
+{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( aspectRatio ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setAspectRatio( float aspectRatio ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAspectRatio_ ), aspectRatio ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::uiTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( uiTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setUITexture(MTL::Texture* uiTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setUITexture_ ), uiTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::jitterOffsetX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setJitterOffsetX( float jitterOffsetX ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetX_ ), jitterOffsetX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + 
+_MTLFX_INLINE float MTLFX::FrameInterpolatorBase::jitterOffsetY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setJitterOffsetY( float jitterOffsetY ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), jitterOffsetY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::isUITextureComposited() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isUITextureComposited ) ); +} + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setIsUITextureComposited( bool uiTextureComposited ) +{ + NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setIsUITextureComposited_ ), uiTextureComposited ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::shouldResetHistory() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( shouldResetHistory ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setShouldResetHistory(bool shouldResetHistory) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setShouldResetHistory_ ), shouldResetHistory ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + 
+_MTLFX_INLINE MTL::Texture* MTLFX::FrameInterpolatorBase::outputTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setOutputTexture(MTL::Texture* outputTexture) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), outputTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::FrameInterpolatorBase::fence() const +{ + return NS::Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setFence(MTL::Fence* fence) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), fence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::FrameInterpolatorBase::isDepthReversed() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolatorBase::setDepthReversed(bool depthReversed) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::FrameInterpolator::encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/dist/include/metal_cpp/MetalFX/MTLFXPrivate.hpp b/dist/include/metal_cpp/MetalFX/MTLFXPrivate.hpp new file mode 100644 index 0000000..21fd728 --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXPrivate.hpp @@ -0,0 +1,482 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+//
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#pragma once
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#include "MTLFXDefines.hpp"
+
+#include <objc/runtime.h>
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#define _MTLFX_PRIVATE_CLS( symbol ) ( MTLFX::Private::Class::s_k##symbol )
+#define _MTLFX_PRIVATE_SEL( accessor ) ( MTLFX::Private::Selector::s_k##accessor )
+
+//-------------------------------------------------------------------------------------------------------------------------------------------------------------
+
+#if defined( MTLFX_PRIVATE_IMPLEMENTATION )
+
+#if defined( METALCPP_SYMBOL_VISIBILITY_HIDDEN )
+#define _MTLFX_PRIVATE_VISIBILITY __attribute__( ( visibility("hidden" ) ) )
+#else
+#define _MTLFX_PRIVATE_VISIBILITY __attribute__( ( visibility("default" ) ) )
+#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN
+
+#define _MTLFX_PRIVATE_IMPORT __attribute__( ( weak_import ) )
+
+#ifdef __OBJC__
+#define _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol ) ( ( __bridge void* ) objc_lookUpClass( #symbol ) )
+#define _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol ) ( ( __bridge void* ) objc_getProtocol( #symbol ) )
+#else
+#define _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol ) objc_lookUpClass(#symbol)
+#define _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol ) objc_getProtocol(#symbol)
+#endif // __OBJC__
+
+#define _MTLFX_PRIVATE_DEF_CLS( symbol ) void* s_k##symbol _MTLFX_PRIVATE_VISIBILITY = _MTLFX_PRIVATE_OBJC_LOOKUP_CLASS( symbol )
+#define _MTLFX_PRIVATE_DEF_PRO( symbol ) void* s_k##symbol _MTLFX_PRIVATE_VISIBILITY = _MTLFX_PRIVATE_OBJC_GET_PROTOCOL( symbol )
+#define _MTLFX_PRIVATE_DEF_SEL( accessor, symbol ) SEL s_k##accessor 
_MTLFX_PRIVATE_VISIBILITY = sel_registerName( symbol )
+
+#include <dlfcn.h>
+#define MTLFX_DEF_FUNC( name, signature ) using Fn##name = signature; \
+ Fn##name name = reinterpret_cast< Fn##name >( dlsym( RTLD_DEFAULT, #name ) )
+
+namespace MTLFX::Private
+{
+ template <class _Type>
+
+ inline _Type const LoadSymbol(const char* pSymbol)
+ {
+ const _Type* pAddress = static_cast<_Type*>(dlsym(RTLD_DEFAULT, pSymbol));
+
+ return pAddress ? *pAddress : nullptr;
+ }
+} // MTLFX::Private
+
+#if defined(__MAC_26_0) || defined(__IPHONE_26_0) || defined(__TVOS_26_0)
+
+#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) \
+ _MTLFX_EXTERN type const MTLFX##symbol _MTLFX_PRIVATE_IMPORT; \
+ type const MTLFX::symbol = ( nullptr != &MTLFX##symbol ) ? MTLFX##symbol : nullptr
+
+#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) \
+ _MTLFX_EXTERN type const MTLFX##symbol _MTLFX_PRIVATE_IMPORT; \
+ type const MTLFX::symbol = (nullptr != &MTLFX##symbol) ? MTLFX##symbol : nullptr
+
+#define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) \
+ _MTLFX_EXTERN type const MTLFX##symbol; \
+ type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol )
+
+#else
+
+#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) \
+ _MTLFX_EXTERN type const MTLFX##symbol; \
+ type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol )
+
+#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) \
+ _MTLFX_EXTERN type const MTLFX##symbol; \
+ type const MTLFX::symbol = Private::LoadSymbol< type >( "MTLFX" #symbol )
+
+#define _MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) _MTLFX_PRIVATE_DEF_CONST( type, symbol )
+
+#endif
+
+#else
+
+#define _MTLFX_PRIVATE_DEF_CLS( symbol ) extern void* s_k##symbol
+#define _MTLFX_PRIVATE_DEF_PRO( symbol ) extern void* s_k##symbol
+#define _MTLFX_PRIVATE_DEF_SEL( accessor, symbol ) extern SEL s_k##accessor
+#define _MTLFX_PRIVATE_DEF_STR( type, symbol ) extern type const MTLFX::symbol
+#define _MTLFX_PRIVATE_DEF_CONST( type, symbol ) extern type const MTLFX::symbol
+#define 
_MTLFX_PRIVATE_DEF_WEAK_CONST( type, symbol ) extern type const MTLFX::symbol + +#endif // MTLFX_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Class + { + _MTLFX_PRIVATE_DEF_CLS( MTLFXSpatialScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXTemporalScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXFrameInterpolatorDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTLFXTemporalDenoisedScalerDescriptor ); + + _MTLFX_PRIVATE_DEF_CLS( MTL4FXSpatialScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXTemporalScalerDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXFrameInterpolatorDescriptor ); + _MTLFX_PRIVATE_DEF_CLS( MTL4FXTemporalDenoisedScalerDescriptor ); + } // Class + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Protocol + { + _MTLFX_PRIVATE_DEF_PRO( MTLFXSpatialScaler ); + _MTLFX_PRIVATE_DEF_PRO( MTLFXTemporalScaler ); + } // Protocol + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTLFX +{ + namespace Private + { + namespace Selector + { + _MTLFX_PRIVATE_DEF_SEL( aspectRatio, + "aspectRatio" ); + _MTLFX_PRIVATE_DEF_SEL( colorProcessingMode, + "colorProcessingMode" ); + _MTLFX_PRIVATE_DEF_SEL( colorTexture, + "colorTexture" ); + _MTLFX_PRIVATE_DEF_SEL( colorTextureFormat, + "colorTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( colorTextureUsage, + "colorTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( deltaTime, + "deltaTime" ); + _MTLFX_PRIVATE_DEF_SEL( denoiseStrengthMaskTexture, + "denoiseStrengthMaskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( 
denoiseStrengthMaskTextureFormat, + "denoiseStrengthMaskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( denoiseStrengthMaskTextureUsage, + "denoiseStrengthMaskTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( depthTexture, + "depthTexture" ); + _MTLFX_PRIVATE_DEF_SEL( depthTextureFormat, + "depthTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( depthTextureUsage, + "depthTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTexture, + "diffuseAlbedoTexture" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTextureFormat, + "diffuseAlbedoTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( diffuseAlbedoTextureUsage, + "diffuseAlbedoTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( encodeToCommandBuffer_, + "encodeToCommandBuffer:" ); + _MTLFX_PRIVATE_DEF_SEL( exposureTexture, + "exposureTexture" ); + _MTLFX_PRIVATE_DEF_SEL( farPlane, + "farPlane" ); + _MTLFX_PRIVATE_DEF_SEL( fence, + "fence" ); + _MTLFX_PRIVATE_DEF_SEL( fieldOfView, + "fieldOfView" ); + _MTLFX_PRIVATE_DEF_SEL( height, + "height" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentHeight, + "inputContentHeight" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentMaxScale, + "inputContentMaxScale" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentMinScale, + "inputContentMinScale" ); + _MTLFX_PRIVATE_DEF_SEL( inputContentWidth, + "inputContentWidth" ); + _MTLFX_PRIVATE_DEF_SEL( inputHeight, + "inputHeight" ); + _MTLFX_PRIVATE_DEF_SEL( inputWidth, + "inputWidth" ); + _MTLFX_PRIVATE_DEF_SEL( isAutoExposureEnabled, + "isAutoExposureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isDenoiseStrengthMaskTextureEnabled, + "isDenoiseStrengthMaskTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isDepthReversed, + "isDepthReversed" ); + _MTLFX_PRIVATE_DEF_SEL( isInputContentPropertiesEnabled, + "isInputContentPropertiesEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isTransparencyOverlayTextureEnabled, + "isTransparencyOverlayTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isReactiveMaskTextureEnabled, + "isReactiveMaskTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isSpecularHitDistanceTextureEnabled, + 
"isSpecularHitDistanceTextureEnabled" ); + _MTLFX_PRIVATE_DEF_SEL( isUITextureComposited, + "isUITextureComposited" ); + _MTLFX_PRIVATE_DEF_SEL( jitterOffsetX, + "jitterOffsetX" ); + _MTLFX_PRIVATE_DEF_SEL( jitterOffsetY, + "jitterOffsetY" ); + _MTLFX_PRIVATE_DEF_SEL( maskTexture, + "maskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( maskTextureFormat, + "maskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( maskTextureUsage, + "maskTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( motionTexture, + "motionTexture" ); + _MTLFX_PRIVATE_DEF_SEL( motionTextureFormat, + "motionTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( motionTextureUsage, + "motionTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( motionVectorScaleX, + "motionVectorScaleX" ); + _MTLFX_PRIVATE_DEF_SEL( motionVectorScaleY, + "motionVectorScaleY" ); + _MTLFX_PRIVATE_DEF_SEL( nearPlane, + "nearPlane" ); + _MTLFX_PRIVATE_DEF_SEL( newFrameInterpolatorWithDevice_, + "newFrameInterpolatorWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newFrameInterpolatorWithDevice_compiler_, + "newFrameInterpolatorWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalDenoisedScalerWithDevice_, + "newTemporalDenoisedScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalDenoisedScalerWithDevice_compiler_, + "newTemporalDenoisedScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newSpatialScalerWithDevice_, + "newSpatialScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newSpatialScalerWithDevice_compiler_, + "newSpatialScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalScalerWithDevice_, + "newTemporalScalerWithDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( newTemporalScalerWithDevice_compiler_, + "newTemporalScalerWithDevice:compiler:" ); + _MTLFX_PRIVATE_DEF_SEL( normalTexture, + "normalTexture" ); + _MTLFX_PRIVATE_DEF_SEL( normalTextureFormat, + "normalTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( normalTextureUsage, + "normalTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( outputHeight, + "outputHeight" ); + _MTLFX_PRIVATE_DEF_SEL( 
outputTexture, + "outputTexture" ); + _MTLFX_PRIVATE_DEF_SEL( outputTextureFormat, + "outputTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( outputTextureUsage, + "outputTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( outputWidth, + "outputWidth" ); + _MTLFX_PRIVATE_DEF_SEL( preExposure, + "preExposure" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTextureFormat, + "transparencyOverlayTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTextureUsage, + "transparencyOverlayTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( prevColorTexture, + "prevColorTexture" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTextureFormat, + "reactiveMaskTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveTextureUsage, + "reactiveTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( reactiveMaskTexture, + "reactiveMaskTexture" ); + _MTLFX_PRIVATE_DEF_SEL( reset, + "reset" ); + _MTLFX_PRIVATE_DEF_SEL( requiresSynchronousInitialization, + "requiresSynchronousInitialization" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTextureFormat, + "roughnessTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTextureUsage, + "roughnessTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( scaler, + "scaler" ); + _MTLFX_PRIVATE_DEF_SEL( scaler4, + "scaler4" ); + _MTLFX_PRIVATE_DEF_SEL( setAspectRatio_, + "setAspectRatio:" ); + _MTLFX_PRIVATE_DEF_SEL( setAutoExposureEnabled_, + "setAutoExposureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorProcessingMode_, + "setColorProcessingMode:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorTexture_, + "setColorTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setColorTextureFormat_, + "setColorTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setDeltaTime_, + "setDeltaTime:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTexture_, + "setDenoiseStrengthMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTextureEnabled_, + "setDenoiseStrengthMaskTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setDenoiseStrengthMaskTextureFormat_, + "setDenoiseStrengthMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( 
setDepthInverted_, + "setDepthInverted:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthReversed_, + "setDepthReversed:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthTexture_, + "setDepthTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDepthTextureFormat_, + "setDepthTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setDiffuseAlbedoTexture_, + "setDiffuseAlbedoTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setDiffuseAlbedoTextureFormat_, + "setDiffuseAlbedoTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setExposureTexture_, + "setExposureTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setFarPlane_, + "setFarPlane:" ); + _MTLFX_PRIVATE_DEF_SEL( setFence_, + "setFence:" ); + _MTLFX_PRIVATE_DEF_SEL( setFieldOfView_, + "setFieldOfView:" ); + _MTLFX_PRIVATE_DEF_SEL( setHeight_, + "setHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentHeight_, + "setInputContentHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentMaxScale_, + "setInputContentMaxScale:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentMinScale_, + "setInputContentMinScale:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentPropertiesEnabled_, + "setInputContentPropertiesEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputContentWidth_, + "setInputContentWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputHeight_, + "setInputHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setInputWidth_, + "setInputWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setIsUITextureComposited_, + "setIsUITextureComposited:" ); + _MTLFX_PRIVATE_DEF_SEL( setJitterOffsetX_, + "setJitterOffsetX:" ); + _MTLFX_PRIVATE_DEF_SEL( setJitterOffsetY_, + "setJitterOffsetY:" ); + _MTLFX_PRIVATE_DEF_SEL( setNearPlane_, + "setNearPlane:" ); + _MTLFX_PRIVATE_DEF_SEL( setMaskTexture_, + "setMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setMaskTextureFormat_, + "setMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionTexture_, + "setMotionTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionTextureFormat_, + "setMotionTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setMotionVectorScaleX_, + "setMotionVectorScaleX:" ); + _MTLFX_PRIVATE_DEF_SEL( 
setMotionVectorScaleY_, + "setMotionVectorScaleY:" ); + _MTLFX_PRIVATE_DEF_SEL( setNormalTexture_, + "setNormalTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setNormalTextureFormat_, + "setNormalTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputHeight_, + "setOutputHeight:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputTexture_, + "setOutputTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputTextureFormat_, + "setOutputTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setOutputWidth_, + "setOutputWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( transparencyOverlayTexture, + "transparencyOverlayTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTexture_, + "setTransparencyOverlayTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTextureEnabled_, + "setTransparencyOverlayTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setPreExposure_, + "setPreExposure:" ); + _MTLFX_PRIVATE_DEF_SEL( setTransparencyOverlayTextureFormat_, + "setTransparencyOverlayTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setPrevColorTexture_, + "setPrevColorTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTexture_, + "setReactiveMaskTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureEnabled_, + "setReactiveMaskTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setReactiveMaskTextureFormat_, + "setReactiveMaskTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setRequiresSynchronousInitialization_, + "setRequiresSynchronousInitialization:" ); + _MTLFX_PRIVATE_DEF_SEL( setReset_, + "setReset:" ); + _MTLFX_PRIVATE_DEF_SEL( roughnessTexture, + "roughnessTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setRoughnessTexture_, + "setRoughnessTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setRoughnessTextureFormat_, + "setRoughnessTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setScaler_, + "setScaler:" ); + _MTLFX_PRIVATE_DEF_SEL( setShouldResetHistory_, + "setShouldResetHistory:" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTexture, + "specularHitDistanceTexture" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTexture_, + 
"setSpecularHitDistanceTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTextureEnabled_, + "setSpecularHitDistanceTextureEnabled:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularAlbedoTexture_, + "setSpecularAlbedoTexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularAlbedoTextureFormat_, + "setSpecularAlbedoTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setSpecularHitDistanceTextureFormat_, + "setSpecularHitDistanceTextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setUITexture_, + "setUITexture:" ); + _MTLFX_PRIVATE_DEF_SEL( setUITextureFormat_, + "setUITextureFormat:" ); + _MTLFX_PRIVATE_DEF_SEL( setViewToClipMatrix_, + "setViewToClipMatrix:" ); + _MTLFX_PRIVATE_DEF_SEL( setWidth_, + "setWidth:" ); + _MTLFX_PRIVATE_DEF_SEL( setWorldToViewMatrix_, + "setWorldToViewMatrix:" ); + _MTLFX_PRIVATE_DEF_SEL( shouldResetHistory, + "shouldResetHistory" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTexture, + "specularAlbedoTexture" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTextureFormat, + "specularAlbedoTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( specularAlbedoTextureUsage, + "specularAlbedoTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTextureFormat, + "specularHitDistanceTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( specularHitDistanceTextureUsage, + "specularHitDistanceTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( supportedInputContentMaxScaleForDevice_, + "supportedInputContentMaxScaleForDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportedInputContentMinScaleForDevice_, + "supportedInputContentMinScaleForDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportsDevice_, + "supportsDevice:" ); + _MTLFX_PRIVATE_DEF_SEL( supportsMetal4FX_, + "supportsMetal4FX:" ); + _MTLFX_PRIVATE_DEF_SEL( uiTexture, + "uiTexture" ); + _MTLFX_PRIVATE_DEF_SEL( uiTextureFormat, + "uiTextureFormat" ); + _MTLFX_PRIVATE_DEF_SEL( uiTextureUsage, + "uiTextureUsage" ); + _MTLFX_PRIVATE_DEF_SEL( viewToClipMatrix, + "viewToClipMatrix" ); + _MTLFX_PRIVATE_DEF_SEL( width, + "width" ); + _MTLFX_PRIVATE_DEF_SEL( 
worldToViewMatrix, + "worldToViewMatrix" ); + } // Selector + } // Private +} // MTLFX + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp b/dist/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp new file mode 100644 index 0000000..cb1186e --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp @@ -0,0 +1,397 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXSpatialScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class SpatialScaler; +} + +namespace MTLFX +{ + _MTLFX_ENUM( NS::Integer, SpatialScalerColorProcessingMode ) + { + SpatialScalerColorProcessingModePerceptual = 0, + SpatialScalerColorProcessingModeLinear = 1, + SpatialScalerColorProcessingModeHDR = 2 + }; + + class SpatialScalerDescriptor : public NS::Copying< SpatialScalerDescriptor > + { + public: + static class SpatialScalerDescriptor* alloc(); + class SpatialScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat format ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger width ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger height ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger width ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger height ); + + SpatialScalerColorProcessingMode colorProcessingMode() const; + void setColorProcessingMode( SpatialScalerColorProcessingMode mode ); + + class SpatialScaler* newSpatialScaler( const MTL::Device* pDevice ) const; + MTL4FX::SpatialScaler* newSpatialScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const; + + static bool supportsDevice( const MTL::Device* 
pDevice); + static bool supportsMetal4FX( const MTL::Device* pDevice ); + }; + + class SpatialScalerBase : public NS::Referencing< SpatialScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + NS::UInteger inputContentWidth() const; + void setInputContentWidth( NS::UInteger width ); + + NS::UInteger inputContentHeight() const; + void setInputContentHeight( NS::UInteger height ); + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* pTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* pTexture ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + SpatialScalerColorProcessingMode colorProcessingMode() const; + + MTL::Fence* fence() const; + void setFence( MTL::Fence* pFence ); + }; + + class SpatialScaler : public NS::Referencing< SpatialScaler, SpatialScalerBase > + { + public: + void encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerDescriptor* MTLFX::SpatialScalerDescriptor::alloc() +{ + return NS::Object::alloc< SpatialScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerDescriptor* MTLFX::SpatialScalerDescriptor::init() +{ + return NS::Object::init< SpatialScalerDescriptor >(); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerDescriptor::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setColorTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerDescriptor::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void 
MTLFX::SpatialScalerDescriptor::setInputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setInputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerDescriptor::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setOutputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerColorProcessingMode MTLFX::SpatialScalerDescriptor::colorProcessingMode() const +{ + return Object::sendMessage< SpatialScalerColorProcessingMode >( this, _MTLFX_PRIVATE_SEL( colorProcessingMode ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerDescriptor::setColorProcessingMode( SpatialScalerColorProcessingMode mode ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorProcessingMode_ ), mode ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScaler* MTLFX::SpatialScalerDescriptor::newSpatialScaler( const MTL::Device* pDevice ) const +{ + return Object::sendMessage< SpatialScaler* >( this, _MTLFX_PRIVATE_SEL( newSpatialScalerWithDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::SpatialScaler* MTLFX::SpatialScalerDescriptor::newSpatialScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const +{ + return Object::sendMessage< MTL4FX::SpatialScaler* >( this, _MTLFX_PRIVATE_SEL( newSpatialScalerWithDevice_compiler_ ), pDevice, 
pCompiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::SpatialScalerDescriptor::supportsDevice( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::SpatialScalerDescriptor::supportsMetal4FX( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXSpatialScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::SpatialScalerBase::colorTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::SpatialScalerBase::outputTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputContentWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentWidth ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setInputContentWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputContentHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setInputContentHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::SpatialScalerBase::colorTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setColorTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::SpatialScalerBase::outputTexture() const +{ + return 
Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setOutputTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerBase::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::SpatialScalerBase::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE 
NS::UInteger MTLFX::SpatialScalerBase::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::SpatialScalerBase::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::SpatialScalerColorProcessingMode MTLFX::SpatialScalerBase::colorProcessingMode() const +{ + return Object::sendMessage< SpatialScalerColorProcessingMode >( this, _MTLFX_PRIVATE_SEL( colorProcessingMode ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::SpatialScalerBase::fence() const +{ + return Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScalerBase::setFence( MTL::Fence* pFence ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), pFence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::SpatialScaler::encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp b/dist/include/metal_cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp new file mode 100644 index 0000000..5863e07 --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXTemporalDenoisedScaler.hpp @@ -0,0 +1,1208 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalDenoisedScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" +#include "MTLFXTemporalScaler.hpp" + +#include "../Metal/Metal.hpp" + +#include <simd/simd.h> + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalDenoisedScaler; +} + +namespace MTLFX +{ + class TemporalDenoisedScalerDescriptor : public NS::Copying< TemporalDenoisedScalerDescriptor > + { + public: + static class TemporalDenoisedScalerDescriptor* alloc(); + class TemporalDenoisedScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat( MTL::PixelFormat pixelFormal ); + + MTL::PixelFormat diffuseAlbedoTextureFormat() const; + void setDiffuseAlbedoTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat specularAlbedoTextureFormat() const; + void setSpecularAlbedoTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat normalTextureFormat() const; + void setNormalTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat roughnessTextureFormat() const; + void setRoughnessTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat specularHitDistanceTextureFormat() const; + void setSpecularHitDistanceTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat denoiseStrengthMaskTextureFormat() const; + void 
setDenoiseStrengthMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat transparencyOverlayTextureFormat() const; + void setTransparencyOverlayTextureFormat( MTL::PixelFormat pixelFormat ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat pixelFormat ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger inputWidth ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger inputHeight ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger outputWidth ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger outputHeight ); + + bool requiresSynchronousInitialization() const; + void setRequiresSynchronousInitialization( bool requiresSynchronousInitialization ); + + bool isAutoExposureEnabled() const; + void setAutoExposureEnabled( bool enabled ); + + bool isInputContentPropertiesEnabled() const; + void setInputContentPropertiesEnabled( bool enabled ); + + float inputContentMinScale() const; + void setInputContentMinScale( float inputContentMinScale ); + + float inputContentMaxScale() const; + void setInputContentMaxScale( float inputContentMaxScale ); + + bool isReactiveMaskTextureEnabled() const; + void setReactiveMaskTextureEnabled( bool enabled ); + + MTL::PixelFormat reactiveMaskTextureFormat() const; + void setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + bool isSpecularHitDistanceTextureEnabled() const; + void setSpecularHitDistanceTextureEnabled( bool enabled ); + + bool isDenoiseStrengthMaskTextureEnabled() const; + void setDenoiseStrengthMaskTextureEnabled( bool enabled ); + + bool isTransparencyOverlayTextureEnabled() const; + void setTransparencyOverlayTextureEnabled( bool enabled ); + + class TemporalDenoisedScaler* newTemporalDenoisedScaler( const MTL::Device* device ) const; + MTL4FX::TemporalDenoisedScaler* newTemporalDenoisedScaler( const MTL::Device* device, const MTL4::Compiler* 
compiler) const; + + static float supportedInputContentMinScale(MTL::Device* device); + static float supportedInputContentMaxScale(MTL::Device* device); + + static bool supportsMetal4FX( MTL::Device* device); + static bool supportsDevice( MTL::Device* device); + }; + + class TemporalDenoisedScalerBase : public NS::Referencing< TemporalDenoisedScalerBase, FrameInterpolatableScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage reactiveTextureUsage() const; + MTL::TextureUsage diffuseAlbedoTextureUsage() const; + MTL::TextureUsage specularAlbedoTextureUsage() const; + MTL::TextureUsage normalTextureUsage() const; + MTL::TextureUsage roughnessTextureUsage() const; + MTL::TextureUsage specularHitDistanceTextureUsage() const; + MTL::TextureUsage denoiseStrengthMaskTextureUsage() const; + MTL::TextureUsage transparencyOverlayTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* colorTexture ); + + MTL::Texture* depthTexture() const; + void setDepthTexture( MTL::Texture* depthTexture ); + + MTL::Texture* motionTexture() const; + void setMotionTexture( MTL::Texture* motionTexture ); + + MTL::Texture* diffuseAlbedoTexture() const; + void setDiffuseAlbedoTexture( MTL::Texture* diffuseAlbedoTexture ); + + MTL::Texture* specularAlbedoTexture() const; + void setSpecularAlbedoTexture( MTL::Texture* specularAlbedoTexture ); + + MTL::Texture* normalTexture() const; + void setNormalTexture( MTL::Texture* normalTexture ); + + MTL::Texture* roughnessTexture() const; + void setRoughnessTexture( MTL::Texture* roughnessTexture ); + + MTL::Texture* specularHitDistanceTexture() const; + void setSpecularHitDistanceTexture( MTL::Texture* specularHitDistanceTexture ); + + MTL::Texture* denoiseStrengthMaskTexture() const; + void setDenoiseStrengthMaskTexture( 
MTL::Texture* denoiseStrengthMaskTexture ); + + MTL::Texture* transparencyOverlayTexture() const; + void setTransparencyOverlayTexture( MTL::Texture* transparencyOverlayTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* outputTexture ); + + MTL::Texture* exposureTexture() const; + void setExposureTexture( MTL::Texture* exposureTexture ); + + float preExposure() const; + void setPreExposure( float preExposure ); + + MTL::Texture* reactiveMaskTexture() const; + void setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ); + + float jitterOffsetX() const; + void setJitterOffsetX( float jitterOffsetX ); + + float jitterOffsetY() const; + void setJitterOffsetY( float jitterOffsetY ); + + float motionVectorScaleX() const; + void setMotionVectorScaleX( float motionVectorScaleX ); + + float motionVectorScaleY() const; + void setMotionVectorScaleY( float motionVectorScaleY ); + + bool shouldResetHistory() const; + void setShouldResetHistory( bool shouldResetHistory ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat diffuseAlbedoTextureFormat() const; + MTL::PixelFormat specularAlbedoTextureFormat() const; + MTL::PixelFormat normalTextureFormat() const; + MTL::PixelFormat roughnessTextureFormat() const; + MTL::PixelFormat specularHitDistanceTextureFormat() const; + MTL::PixelFormat denoiseStrengthMaskTextureFormat() const; + MTL::PixelFormat transparencyOverlayTextureFormat() const; + MTL::PixelFormat reactiveMaskTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + float inputContentMinScale() const; + float inputContentMaxScale() const; + + simd::float4x4 worldToViewMatrix() 
const; + void setWorldToViewMatrix( simd::float4x4 worldToViewMatrix ); + + simd::float4x4 viewToClipMatrix() const; + void setViewToClipMatrix( simd::float4x4 viewToClipMatrix ); + + MTL::Fence* fence() const; + void setFence( MTL::Fence* fence ); + }; + + class TemporalDenoisedScaler : public NS::Referencing< TemporalDenoisedScaler, TemporalDenoisedScalerBase > + { + public: + + void encodeToCommandBuffer(MTL::CommandBuffer* commandBuffer); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScalerDescriptor* MTLFX::TemporalDenoisedScalerDescriptor::alloc() +{ + return NS::Object::alloc< TemporalDenoisedScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScalerDescriptor* MTLFX::TemporalDenoisedScalerDescriptor::init() +{ + return NS::Object::init< TemporalDenoisedScalerDescriptor >(); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setColorTextureFormat( MTL::PixelFormat pixelFormat ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), pixelFormat ); +} + 
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — pixel-format property bridges
// (each accessor forwards to the corresponding Objective-C selector).

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::depthTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDepthTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::motionTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setMotionTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::diffuseAlbedoTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — albedo/normal pixel-format bridges.

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDiffuseAlbedoTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDiffuseAlbedoTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::specularAlbedoTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularAlbedoTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularAlbedoTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::normalTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( normalTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setNormalTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNormalTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — remaining pixel-format bridges and
// the input/output dimension properties (NS::UInteger, in pixels).

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::roughnessTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( roughnessTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setRoughnessTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRoughnessTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::specularHitDistanceTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularHitDistanceTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::denoiseStrengthMaskTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDenoiseStrengthMaskTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::transparencyOverlayTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setTransparencyOverlayTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::outputTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::inputWidth() const
{
    return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputWidth( NS::UInteger inputWidth )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), inputWidth );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::inputHeight() const
{
    return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputHeight( NS::UInteger inputHeight )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), inputHeight );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::outputWidth() const
{
    return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — output dimensions and the
// synchronous-initialization flag.

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputWidth( NS::UInteger outputWidth )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), outputWidth );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerDescriptor::outputHeight() const
{
    return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setOutputHeight( NS::UInteger outputHeight )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), outputHeight );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::requiresSynchronousInitialization() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( requiresSynchronousInitialization ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setRequiresSynchronousInitialization( bool requiresSynchronousInitialization )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRequiresSynchronousInitialization_ ), requiresSynchronousInitialization );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — feature-flag bridges
// (auto exposure, input-content properties) and the min content scale.

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isAutoExposureEnabled() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isAutoExposureEnabled ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setAutoExposureEnabled( bool enabled )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAutoExposureEnabled_ ), enabled );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isInputContentPropertiesEnabled() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isInputContentPropertiesEnabled ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentPropertiesEnabled( bool enabled )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentPropertiesEnabled_ ), enabled );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::inputContentMinScale() const
{
    return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — content-scale range and the
// reactive-mask feature flag.

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentMinScale( float inputContentMinScale )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMinScale_ ), inputContentMinScale );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::inputContentMaxScale() const
{
    return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setInputContentMaxScale( float inputContentMaxScale )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMaxScale_ ), inputContentMaxScale );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isReactiveMaskTextureEnabled() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isReactiveMaskTextureEnabled ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setReactiveMaskTextureEnabled( bool enabled )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureEnabled_ ), enabled );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerDescriptor — reactive-mask format and the
// specular-hit-distance / denoise-strength-mask feature flags.

_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerDescriptor::reactiveMaskTextureFormat() const
{
    return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureFormat_ ), pixelFormat );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isSpecularHitDistanceTextureEnabled() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isSpecularHitDistanceTextureEnabled ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setSpecularHitDistanceTextureEnabled( bool enabled )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTextureEnabled_ ), enabled );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isDenoiseStrengthMaskTextureEnabled() const
{
    return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDenoiseStrengthMaskTextureEnabled ) );
}

+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setDenoiseStrengthMaskTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::isTransparencyOverlayTextureEnabled() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isTransparencyOverlayTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerDescriptor::setTransparencyOverlayTextureEnabled( bool enabled ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalDenoisedScaler* MTLFX::TemporalDenoisedScalerDescriptor::newTemporalDenoisedScaler( const MTL::Device* device ) const +{ + return NS::Object::sendMessage< TemporalDenoisedScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalDenoisedScalerWithDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::TemporalDenoisedScaler* MTLFX::TemporalDenoisedScalerDescriptor::newTemporalDenoisedScaler( const MTL::Device* device, const MTL4::Compiler* compiler ) const +{ + 
return NS::Object::sendMessage< MTL4FX::TemporalDenoisedScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalDenoisedScalerWithDevice_compiler_ ), device, compiler ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::supportedInputContentMinScale( MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ), pDevice ); + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerDescriptor::supportedInputContentMaxScale( MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _MTLFX_PRIVATE_CLS( MTLFXTemporalDenoisedScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ), pDevice ); + } + else if ( supportsDevice( pDevice ) ) + { + scale = 2.0f; + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::supportsMetal4FX( MTL::Device* device ) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXTemporalDenoisedScalerDescriptor), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), 
device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerDescriptor::supportsDevice( MTL::Device* device ) +{ + return NS::Object::sendMessageSafe< bool >( _MTLFX_PRIVATE_CLS(MTLFXTemporalDenoisedScalerDescriptor), _MTLFX_PRIVATE_SEL( supportsDevice_ ), device ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::colorTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::depthTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::motionTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::reactiveTextureUsage() const +{ + return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( reactiveTextureUsage ) ); +} + 
//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerBase — required texture-usage flags for the
// G-buffer style inputs (albedo, normal, roughness, specular hit distance).

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::specularAlbedoTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::normalTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( normalTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::roughnessTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( roughnessTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

// MTLFX::TemporalDenoisedScalerBase — remaining usage flags plus the per-frame
// color texture accessor pair.

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalDenoisedScalerBase::outputTextureUsage() const
{
    return NS::Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::colorTexture() const
{
    return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) );
}

//-------------------------------------------------------------------------------------------------------------------------------------------------------------

_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setColorTexture( MTL::Texture* colorTexture )
{
    return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), colorTexture );
}

+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::depthTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDepthTexture( MTL::Texture* depthTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), depthTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::motionTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setMotionTexture( MTL::Texture* motionTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), motionTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + 
+_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDiffuseAlbedoTexture( MTL::Texture* diffuseAlbedoTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDiffuseAlbedoTexture_ ), diffuseAlbedoTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::specularAlbedoTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setSpecularAlbedoTexture( MTL::Texture* specularAlbedoTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularAlbedoTexture_ ), specularAlbedoTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::normalTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( normalTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setNormalTexture( MTL::Texture* normalTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setNormalTexture_ ), normalTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* 
MTLFX::TemporalDenoisedScalerBase::roughnessTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( roughnessTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setRoughnessTexture( MTL::Texture* roughnessTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRoughnessTexture_ ), roughnessTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setSpecularHitDistanceTexture( MTL::Texture* specularHitDistanceTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setSpecularHitDistanceTexture_ ), specularHitDistanceTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDenoiseStrengthMaskTexture( 
MTL::Texture* denoiseStrengthMaskTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDenoiseStrengthMaskTexture_ ), denoiseStrengthMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setTransparencyOverlayTexture( MTL::Texture* transparencyOverlayTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setTransparencyOverlayTexture_ ), transparencyOverlayTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::outputTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setOutputTexture( MTL::Texture* outputTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), outputTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::exposureTexture() const +{ + return 
NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( exposureTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setExposureTexture( MTL::Texture* exposureTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setExposureTexture_ ), exposureTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::preExposure() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( preExposure ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setPreExposure( float preExposure ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPreExposure_ ), preExposure ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalDenoisedScalerBase::reactiveMaskTexture() const +{ + return NS::Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTexture_ ), reactiveMaskTexture ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::jitterOffsetX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setJitterOffsetX( float jitterOffsetX ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetX_ ), jitterOffsetX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::jitterOffsetY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setJitterOffsetY( float jitterOffsetY ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), jitterOffsetY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::motionVectorScaleX() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void 
MTLFX::TemporalDenoisedScalerBase::setMotionVectorScaleX( float motionVectorScaleX ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), motionVectorScaleX ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::motionVectorScaleY() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setMotionVectorScaleY( float motionVectorScaleY ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), motionVectorScaleY ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerBase::shouldResetHistory() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( shouldResetHistory ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setShouldResetHistory( bool shouldResetHistory ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setShouldResetHistory_ ), shouldResetHistory ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalDenoisedScalerBase::isDepthReversed() const +{ + return NS::Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( 
isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setDepthReversed( bool depthReversed ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::colorTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::depthTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::motionTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::diffuseAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( diffuseAlbedoTextureFormat ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::specularAlbedoTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularAlbedoTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::normalTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( normalTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::roughnessTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( roughnessTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::specularHitDistanceTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( specularHitDistanceTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::denoiseStrengthMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( denoiseStrengthMaskTextureFormat ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::transparencyOverlayTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( transparencyOverlayTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::reactiveMaskTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalDenoisedScalerBase::outputTextureFormat() const +{ + return NS::Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::inputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::inputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::outputWidth() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalDenoisedScalerBase::outputHeight() const +{ + return NS::Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::inputContentMinScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalDenoisedScalerBase::inputContentMaxScale() const +{ + return NS::Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE simd::float4x4 MTLFX::TemporalDenoisedScalerBase::worldToViewMatrix() const +{ + return NS::Object::sendMessage< simd::float4x4 >( this, _MTLFX_PRIVATE_SEL( worldToViewMatrix ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void 
MTLFX::TemporalDenoisedScalerBase::setWorldToViewMatrix( simd::float4x4 worldToViewMatrix ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setWorldToViewMatrix_ ), worldToViewMatrix ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE simd::float4x4 MTLFX::TemporalDenoisedScalerBase::viewToClipMatrix() const +{ + return NS::Object::sendMessage< simd::float4x4 >( this, _MTLFX_PRIVATE_SEL( viewToClipMatrix ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setViewToClipMatrix( simd::float4x4 viewToClipMatrix ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setViewToClipMatrix_ ), viewToClipMatrix ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::TemporalDenoisedScalerBase::fence() const +{ + return NS::Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScalerBase::setFence( MTL::Fence* fence ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), fence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalDenoisedScaler::encodeToCommandBuffer( MTL::CommandBuffer* commandBuffer ) +{ + return NS::Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( 
encodeToCommandBuffer_ ), commandBuffer ); +} diff --git a/dist/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp b/dist/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp new file mode 100644 index 0000000..c13d424 --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp @@ -0,0 +1,803 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MTLFXTemporalScaler.hpp +// +// Copyright 2020-2025 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXDefines.hpp" +#include "MTLFXPrivate.hpp" + +#include "../Metal/Metal.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace MTL4FX +{ + class TemporalScaler; +} + +namespace MTLFX +{ + class TemporalScalerDescriptor : public NS::Copying< TemporalScalerDescriptor > + { + public: + static class TemporalScalerDescriptor* alloc(); + class TemporalScalerDescriptor* init(); + + MTL::PixelFormat colorTextureFormat() const; + void setColorTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat depthTextureFormat() const; + void setDepthTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat motionTextureFormat() const; + void setMotionTextureFormat( MTL::PixelFormat format ); + + MTL::PixelFormat outputTextureFormat() const; + void setOutputTextureFormat( MTL::PixelFormat format ); + + NS::UInteger inputWidth() const; + void setInputWidth( NS::UInteger width ); + + NS::UInteger inputHeight() const; + void setInputHeight( NS::UInteger height ); + + NS::UInteger outputWidth() const; + void setOutputWidth( NS::UInteger width ); + + NS::UInteger outputHeight() const; + void setOutputHeight( NS::UInteger height ); + + bool isAutoExposureEnabled() const; + void setAutoExposureEnabled( bool enabled ); + + bool isInputContentPropertiesEnabled() const; + void setInputContentPropertiesEnabled( bool enabled ); + + bool requiresSynchronousInitialization() const; + void setRequiresSynchronousInitialization(bool requiresSynchronousInitialization); + + bool isReactiveMaskTextureEnabled() const; 
+ void setReactiveMaskTextureEnabled( bool enabled ); + + MTL::PixelFormat reactiveMaskTextureFormat() const; + void setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ); + + float inputContentMinScale() const; + void setInputContentMinScale( float scale ); + + float inputContentMaxScale() const; + void setInputContentMaxScale( float scale ); + + class TemporalScaler* newTemporalScaler( const MTL::Device* pDevice ) const; + MTL4FX::TemporalScaler* newTemporalScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler) const; + + static float supportedInputContentMinScale( const MTL::Device* pDevice ); + static float supportedInputContentMaxScale( const MTL::Device* pDevice ); + + static bool supportsDevice( const MTL::Device* pDevice ); + static bool supportsMetal4FX( const MTL::Device* pDevice ); + }; + + class FrameInterpolatableScaler : public NS::Copying< FrameInterpolatableScaler > + { + }; + + class TemporalScalerBase : public NS::Referencing< TemporalScaler, FrameInterpolatableScaler > + { + public: + MTL::TextureUsage colorTextureUsage() const; + MTL::TextureUsage depthTextureUsage() const; + MTL::TextureUsage motionTextureUsage() const; + MTL::TextureUsage outputTextureUsage() const; + + NS::UInteger inputContentWidth() const; + void setInputContentWidth( NS::UInteger width ); + + NS::UInteger inputContentHeight() const; + void setInputContentHeight( NS::UInteger height ); + + MTL::Texture* colorTexture() const; + void setColorTexture( MTL::Texture* pTexture ); + + MTL::Texture* depthTexture() const; + void setDepthTexture( MTL::Texture* pTexture ); + + MTL::Texture* motionTexture() const; + void setMotionTexture( MTL::Texture* pTexture ); + + MTL::Texture* outputTexture() const; + void setOutputTexture( MTL::Texture* pTexture ); + + MTL::Texture* exposureTexture() const; + void setExposureTexture( MTL::Texture* pTexture ); + + float preExposure() const; + void setPreExposure( float preExposure ); + + float jitterOffsetX() const; + void 
setJitterOffsetX( float offset ); + + float jitterOffsetY() const; + void setJitterOffsetY( float offset ); + + float motionVectorScaleX() const; + void setMotionVectorScaleX( float scale ); + + float motionVectorScaleY() const; + void setMotionVectorScaleY( float scale ); + + MTL::Texture* reactiveMaskTexture() const; + void setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ); + + MTL::TextureUsage reactiveTextureUsage() const; + + bool reset() const; + void setReset( bool reset ); + + bool isDepthReversed() const; + void setDepthReversed( bool depthReversed ); + + MTL::PixelFormat colorTextureFormat() const; + MTL::PixelFormat depthTextureFormat() const; + MTL::PixelFormat motionTextureFormat() const; + MTL::PixelFormat reactiveTextureFormat() const; + MTL::PixelFormat outputTextureFormat() const; + NS::UInteger inputWidth() const; + NS::UInteger inputHeight() const; + NS::UInteger outputWidth() const; + NS::UInteger outputHeight() const; + float inputContentMinScale() const; + float inputContentMaxScale() const; + + MTL::Fence* fence() const; + void setFence( MTL::Fence* pFence ); + }; + + class TemporalScaler : public NS::Referencing< TemporalScaler, TemporalScalerBase > + { + public: + void encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ); + }; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScalerDescriptor* MTLFX::TemporalScalerDescriptor::alloc() +{ + return NS::Object::alloc< TemporalScalerDescriptor >( _MTLFX_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScalerDescriptor* MTLFX::TemporalScalerDescriptor::init() +{ + return NS::Object::init< TemporalScalerDescriptor >(); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setColorTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::depthTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setDepthTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::motionTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + 
+_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setMotionTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputTextureFormat( MTL::PixelFormat format ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTextureFormat_ ), format ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerDescriptor::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setOutputHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isAutoExposureEnabled() const +{ + return 
Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isAutoExposureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setAutoExposureEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setAutoExposureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isInputContentPropertiesEnabled() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isInputContentPropertiesEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentPropertiesEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentPropertiesEnabled_ ), enabled ); +} + + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::requiresSynchronousInitialization() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( requiresSynchronousInitialization ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setRequiresSynchronousInitialization(bool requiresSynchronousInitialization) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setRequiresSynchronousInitialization_ ), requiresSynchronousInitialization ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::isReactiveMaskTextureEnabled() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isReactiveMaskTextureEnabled ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureEnabled( bool enabled ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureEnabled_ ), enabled ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerDescriptor::reactiveMaskTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setReactiveMaskTextureFormat( MTL::PixelFormat pixelFormat ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTextureFormat_ ), pixelFormat ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::inputContentMinScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentMinScale( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMinScale_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::inputContentMaxScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerDescriptor::setInputContentMaxScale( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentMaxScale_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTLFX::TemporalScaler* MTLFX::TemporalScalerDescriptor::newTemporalScaler( const MTL::Device* pDevice ) const +{ + return Object::sendMessage< TemporalScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalScalerWithDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL4FX::TemporalScaler* MTLFX::TemporalScalerDescriptor::newTemporalScaler( const MTL::Device* pDevice, const MTL4::Compiler* pCompiler ) const +{ + return Object::sendMessage< MTL4FX::TemporalScaler* >( this, _MTLFX_PRIVATE_SEL( newTemporalScalerWithDevice_compiler_ ), pDevice, pCompiler ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::supportedInputContentMinScale( const MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMinScaleForDevice_ ), pDevice ); + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerDescriptor::supportedInputContentMaxScale( const MTL::Device* pDevice ) +{ + float scale = 1.0f; + + if ( nullptr != methodSignatureForSelector( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ) ) ) + { + scale = sendMessage< float >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportedInputContentMaxScaleForDevice_ ), pDevice ); + } + else if ( supportsDevice( pDevice ) ) + { + scale = 2.0f; + } + + return scale; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerDescriptor::supportsDevice( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsDevice_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool 
MTLFX::TemporalScalerDescriptor::supportsMetal4FX( const MTL::Device* pDevice ) +{ + return Object::sendMessageSafe< bool >( _NS_PRIVATE_CLS( MTLFXTemporalScalerDescriptor ), _MTLFX_PRIVATE_SEL( supportsMetal4FX_ ), pDevice ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::colorTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( colorTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::depthTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( depthTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::motionTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( motionTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::outputTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( outputTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputContentWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( 
inputContentWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setInputContentWidth( NS::UInteger width ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentWidth_ ), width ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputContentHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputContentHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setInputContentHeight( NS::UInteger height ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setInputContentHeight_ ), height ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::colorTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( colorTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setColorTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setColorTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* 
MTLFX::TemporalScalerBase::depthTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( depthTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setDepthTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::motionTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( motionTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::outputTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( outputTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setOutputTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setOutputTexture_ ), pTexture ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::exposureTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( exposureTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setExposureTexture( MTL::Texture* pTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setExposureTexture_ ), pTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::preExposure() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( preExposure ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setPreExposure( float preExposure ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setPreExposure_ ), preExposure ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::jitterOffsetX() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setJitterOffsetX( float offset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( 
setJitterOffsetX_ ), offset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::jitterOffsetY() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( jitterOffsetY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setJitterOffsetY( float offset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setJitterOffsetY_ ), offset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::motionVectorScaleX() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleX ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionVectorScaleX( float scale ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleX_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::motionVectorScaleY() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( motionVectorScaleY ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setMotionVectorScaleY( float scale ) +{ + Object::sendMessage< void >( 
this, _MTLFX_PRIVATE_SEL( setMotionVectorScaleY_ ), scale ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Texture* MTLFX::TemporalScalerBase::reactiveMaskTexture() const +{ + return Object::sendMessage< MTL::Texture* >( this, _MTLFX_PRIVATE_SEL( reactiveMaskTexture ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setReactiveMaskTexture( MTL::Texture* reactiveMaskTexture ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReactiveMaskTexture_ ), reactiveMaskTexture ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::TextureUsage MTLFX::TemporalScalerBase::reactiveTextureUsage() const +{ + return Object::sendMessage< MTL::TextureUsage >( this, _MTLFX_PRIVATE_SEL( reactiveTextureUsage ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool MTLFX::TemporalScalerBase::reset() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( reset ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setReset( bool reset ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setReset_ ), reset ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE bool 
MTLFX::TemporalScalerBase::isDepthReversed() const +{ + return Object::sendMessage< bool >( this, _MTLFX_PRIVATE_SEL( isDepthReversed ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setDepthReversed( bool depthReversed ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setDepthReversed_ ), depthReversed ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::colorTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( colorTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::depthTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( depthTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::motionTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( motionTextureFormat ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::PixelFormat MTLFX::TemporalScalerBase::outputTextureFormat() const +{ + return Object::sendMessage< MTL::PixelFormat >( this, _MTLFX_PRIVATE_SEL( outputTextureFormat ) ); +} + 
+//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::inputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( inputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::outputWidth() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputWidth ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE NS::UInteger MTLFX::TemporalScalerBase::outputHeight() const +{ + return Object::sendMessage< NS::UInteger >( this, _MTLFX_PRIVATE_SEL( outputHeight ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::inputContentMinScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( inputContentMinScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE float MTLFX::TemporalScalerBase::inputContentMaxScale() const +{ + return Object::sendMessage< float >( this, _MTLFX_PRIVATE_SEL( 
inputContentMaxScale ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE MTL::Fence* MTLFX::TemporalScalerBase::fence() const +{ + return Object::sendMessage< MTL::Fence* >( this, _MTLFX_PRIVATE_SEL( fence ) ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScalerBase::setFence( MTL::Fence* pFence ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( setFence_ ), pFence ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_MTLFX_INLINE void MTLFX::TemporalScaler::encodeToCommandBuffer( MTL::CommandBuffer* pCommandBuffer ) +{ + Object::sendMessage< void >( this, _MTLFX_PRIVATE_SEL( encodeToCommandBuffer_ ), pCommandBuffer ); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/MetalFX/MetalFX.hpp b/dist/include/metal_cpp/MetalFX/MetalFX.hpp new file mode 100644 index 0000000..20e647e --- /dev/null +++ b/dist/include/metal_cpp/MetalFX/MetalFX.hpp @@ -0,0 +1,35 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// MetalFX/MetalFX.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "MTLFXSpatialScaler.hpp" +#include "MTLFXTemporalScaler.hpp" +#include "MTLFXTemporalDenoisedScaler.hpp" +#include "MTLFXFrameInterpolator.hpp" + +#include "MTL4FXSpatialScaler.hpp" +#include "MTL4FXTemporalScaler.hpp" +#include "MTL4FXTemporalDenoisedScaler.hpp" +#include "MTL4FXFrameInterpolator.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/QuartzCore/CADefines.hpp b/dist/include/metal_cpp/QuartzCore/CADefines.hpp new file mode 100644 index 0000000..b0641de --- /dev/null +++ b/dist/include/metal_cpp/QuartzCore/CADefines.hpp @@ -0,0 +1,41 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CADefines.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Foundation/NSDefines.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _CA_EXPORT _NS_EXPORT +#define _CA_EXTERN _NS_EXTERN +#define _CA_INLINE _NS_INLINE +#define _CA_PACKED _NS_PACKED + +#define _CA_CONST(type, name) _NS_CONST(type, name) +#define _CA_ENUM(type, name) _NS_ENUM(type, name) +#define _CA_OPTIONS(type, name) _NS_OPTIONS(type, name) + +#define _CA_VALIDATE_SIZE(ns, name) _NS_VALIDATE_SIZE(ns, name) +#define _CA_VALIDATE_ENUM(ns, name) _NS_VALIDATE_ENUM(ns, name) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp b/dist/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp new file mode 100644 index 0000000..0057773 --- /dev/null +++ b/dist/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp @@ -0,0 +1,57 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// 
QuartzCore/CAMetalDrawable.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Metal/MTLDrawable.hpp" +#include "../Metal/MTLTexture.hpp" + +#include "CADefines.hpp" +#include "CAPrivate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +class MetalDrawable : public NS::Referencing +{ +public: + class MetalLayer* layer() const; + MTL::Texture* texture() const; +}; +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CA::MetalLayer* CA::MetalDrawable::layer() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(layer)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::Texture* CA::MetalDrawable::texture() const +{ + return Object::sendMessage(this, 
_CA_PRIVATE_SEL(texture)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/QuartzCore/CAMetalLayer.hpp b/dist/include/metal_cpp/QuartzCore/CAMetalLayer.hpp new file mode 100644 index 0000000..53f6857 --- /dev/null +++ b/dist/include/metal_cpp/QuartzCore/CAMetalLayer.hpp @@ -0,0 +1,216 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CAMetalDrawable.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "../Metal/MTLPixelFormat.hpp" +#include "../Metal/MTLTexture.hpp" +#include "../Metal/MTLResidencySet.hpp" +#include "../Foundation/NSTypes.hpp" +#include +#include + +#include "CADefines.hpp" +#include "CAMetalDrawable.hpp" +#include "CAPrivate.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ + +class MetalLayer : public NS::Referencing +{ +public: + static class MetalLayer* layer(); + + MTL::Device* device() const; + void setDevice(MTL::Device* device); + + MTL::PixelFormat pixelFormat() const; + void setPixelFormat(MTL::PixelFormat pixelFormat); + + bool framebufferOnly() const; + void setFramebufferOnly(bool framebufferOnly); + + CGSize drawableSize() const; + void setDrawableSize(CGSize drawableSize); + + class MetalDrawable* nextDrawable(); + + NS::UInteger maximumDrawableCount() const; + void setMaximumDrawableCount(NS::UInteger maximumDrawableCount); + + bool displaySyncEnabled() const; + void setDisplaySyncEnabled(bool displaySyncEnabled); + + CGColorSpaceRef colorspace() const; + void setColorspace(CGColorSpaceRef colorspace); + + bool allowsNextDrawableTimeout() const; + void setAllowsNextDrawableTimeout(bool allowsNextDrawableTimeout); + + MTL::ResidencySet* residencySet() const; +}; +} // namespace CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +_CA_INLINE CA::MetalLayer* CA::MetalLayer::layer() +{ + return 
Object::sendMessage(_CA_PRIVATE_CLS(CAMetalLayer), _CA_PRIVATE_SEL(layer)); +} +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::Device* CA::MetalLayer::device() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(device)); +} +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDevice(MTL::Device* device) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDevice_), device); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::PixelFormat CA::MetalLayer::pixelFormat() const +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(pixelFormat)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setPixelFormat(MTL::PixelFormat pixelFormat) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setPixelFormat_), + pixelFormat); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::framebufferOnly() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(framebufferOnly)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setFramebufferOnly(bool 
framebufferOnly) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setFramebufferOnly_), + framebufferOnly); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CGSize CA::MetalLayer::drawableSize() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(drawableSize)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDrawableSize(CGSize drawableSize) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDrawableSize_), + drawableSize); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CA::MetalDrawable* CA::MetalLayer::nextDrawable() +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(nextDrawable)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE NS::UInteger CA::MetalLayer::maximumDrawableCount() const +{ + return Object::sendMessage(this, + _CA_PRIVATE_SEL(maximumDrawableCount)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setMaximumDrawableCount(NS::UInteger maximumDrawableCount) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setMaximumDrawableCount_), + maximumDrawableCount); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::displaySyncEnabled() const +{ + return Object::sendMessage(this, 
_CA_PRIVATE_SEL(displaySyncEnabled)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setDisplaySyncEnabled(bool displaySyncEnabled) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setDisplaySyncEnabled_), + displaySyncEnabled); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE CGColorSpaceRef CA::MetalLayer::colorspace() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(colorspace)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setColorspace(CGColorSpaceRef colorspace) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setColorspace_), + colorspace); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE bool CA::MetalLayer::allowsNextDrawableTimeout() const +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(allowsNextDrawableTimeout)); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE void CA::MetalLayer::setAllowsNextDrawableTimeout(bool allowsNextDrawableTimeout) +{ + return Object::sendMessage(this, _CA_PRIVATE_SEL(setAllowsNextDrawableTimeout_), + allowsNextDrawableTimeout); +} + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +_CA_INLINE MTL::ResidencySet* CA::MetalLayer::residencySet() const +{ + return Object::sendMessage(this, 
_CA_PRIVATE_SEL(residencySet) ); +} diff --git a/dist/include/metal_cpp/QuartzCore/CAPrivate.hpp b/dist/include/metal_cpp/QuartzCore/CAPrivate.hpp new file mode 100644 index 0000000..0b7486a --- /dev/null +++ b/dist/include/metal_cpp/QuartzCore/CAPrivate.hpp @@ -0,0 +1,150 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/CAPrivate.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "CADefines.hpp" + +#include + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#define _CA_PRIVATE_CLS(symbol) (Private::Class::s_k##symbol) +#define _CA_PRIVATE_SEL(accessor) (Private::Selector::s_k##accessor) + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#if defined(CA_PRIVATE_IMPLEMENTATION) + +#ifdef METALCPP_SYMBOL_VISIBILITY_HIDDEN +#define _CA_PRIVATE_VISIBILITY __attribute__((visibility("hidden"))) +#else +#define _CA_PRIVATE_VISIBILITY __attribute__((visibility("default"))) +#endif // METALCPP_SYMBOL_VISIBILITY_HIDDEN + +#define _CA_PRIVATE_IMPORT __attribute__((weak_import)) + +#ifdef __OBJC__ +#define _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) ((__bridge void*)objc_lookUpClass(#symbol)) +#define _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) ((__bridge void*)objc_getProtocol(#symbol)) +#else +#define _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) objc_lookUpClass(#symbol) +#define _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) objc_getProtocol(#symbol) +#endif // __OBJC__ + +#define _CA_PRIVATE_DEF_CLS(symbol) void* s_k##symbol _CA_PRIVATE_VISIBILITY = _CA_PRIVATE_OBJC_LOOKUP_CLASS(symbol) +#define _CA_PRIVATE_DEF_PRO(symbol) void* s_k##symbol _CA_PRIVATE_VISIBILITY = _CA_PRIVATE_OBJC_GET_PROTOCOL(symbol) +#define _CA_PRIVATE_DEF_SEL(accessor, symbol) SEL s_k##accessor _CA_PRIVATE_VISIBILITY = sel_registerName(symbol) +#define _CA_PRIVATE_DEF_STR(type, symbol) \ + _CA_EXTERN type const CA##symbol 
_CA_PRIVATE_IMPORT; \ + type const CA::symbol = (nullptr != &CA##symbol) ? CA##symbol : nullptr + +#else + +#define _CA_PRIVATE_DEF_CLS(symbol) extern void* s_k##symbol +#define _CA_PRIVATE_DEF_PRO(symbol) extern void* s_k##symbol +#define _CA_PRIVATE_DEF_SEL(accessor, symbol) extern SEL s_k##accessor +#define _CA_PRIVATE_DEF_STR(type, symbol) extern type const CA::symbol + +#endif // CA_PRIVATE_IMPLEMENTATION + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Class + { + _CA_PRIVATE_DEF_CLS(CAMetalLayer); + } // Class +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Protocol + { + + _CA_PRIVATE_DEF_PRO(CAMetalDrawable); + + } // Protocol +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +namespace CA +{ +namespace Private +{ + namespace Selector + { + _CA_PRIVATE_DEF_SEL(allowsNextDrawableTimeout, + "allowsNextDrawableTimeout"); + _CA_PRIVATE_DEF_SEL(colorspace, + "colorspace"); + _CA_PRIVATE_DEF_SEL(device, + "device"); + _CA_PRIVATE_DEF_SEL(displaySyncEnabled, + "displaySyncEnabled"); + _CA_PRIVATE_DEF_SEL(drawableSize, + "drawableSize"); + _CA_PRIVATE_DEF_SEL(framebufferOnly, + "framebufferOnly"); + _CA_PRIVATE_DEF_SEL(layer, + "layer"); + _CA_PRIVATE_DEF_SEL(maximumDrawableCount, + "maximumDrawableCount"); + _CA_PRIVATE_DEF_SEL(nextDrawable, + "nextDrawable"); + _CA_PRIVATE_DEF_SEL(pixelFormat, + "pixelFormat"); + _CA_PRIVATE_DEF_SEL(residencySet, + "residencySet"); + _CA_PRIVATE_DEF_SEL(setAllowsNextDrawableTimeout_, + "setAllowsNextDrawableTimeout:"); + 
_CA_PRIVATE_DEF_SEL(setColorspace_, + "setColorspace:"); + _CA_PRIVATE_DEF_SEL(setDevice_, + "setDevice:"); + _CA_PRIVATE_DEF_SEL(setDisplaySyncEnabled_, + "setDisplaySyncEnabled:"); + _CA_PRIVATE_DEF_SEL(setDrawableSize_, + "setDrawableSize:"); + _CA_PRIVATE_DEF_SEL(setFramebufferOnly_, + "setFramebufferOnly:"); + _CA_PRIVATE_DEF_SEL(setMaximumDrawableCount_, + "setMaximumDrawableCount:"); + _CA_PRIVATE_DEF_SEL(setPixelFormat_, + "setPixelFormat:"); + _CA_PRIVATE_DEF_SEL(texture, + "texture"); + } // Class +} // Private +} // CA + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/QuartzCore/QuartzCore.hpp b/dist/include/metal_cpp/QuartzCore/QuartzCore.hpp new file mode 100644 index 0000000..681003a --- /dev/null +++ b/dist/include/metal_cpp/QuartzCore/QuartzCore.hpp @@ -0,0 +1,28 @@ +//------------------------------------------------------------------------------------------------------------------------------------------------------------- +// +// QuartzCore/QuartzCore.hpp +// +// Copyright 2020-2024 Apple Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#pragma once + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- + +#include "CAMetalDrawable.hpp" +#include "CAMetalLayer.hpp" + +//------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/metal_cpp/README.md b/dist/include/metal_cpp/README.md new file mode 100644 index 0000000..52ae7b5 --- /dev/null +++ b/dist/include/metal_cpp/README.md @@ -0,0 +1,313 @@ +## About + +**metal-cpp** is a low overhead and header only C++ interface for Metal that helps developers add Metal functionality to graphics applications that are written in C++ (such as game engines). **metal-cpp** removes the need to create a shim and allows developers to call Metal functions directly from anywhere in their existing C++ code. + + +## Highlights + +- Drop in C++ alternative interface to the Metal Objective-C headers. +- Direct mapping of all Metal Objective-C classes, constants and enums to C++ in the MTL C++ namespace. +- No measurable overhead compared to calling Metal Objective-C headers, due to inlining of C++ function calls. +- No usage of wrapper containers that require additional allocations. +- Requires C++17 due to the usage of `constexpr` in `NS::Object`. +- Identical header files and function/constant/enum availability for iOS, macOS and tvOS. +- Backwards compatibility: All `bool MTL::Device::supports...()` functions check if their required selectors exist and automatically return `false` if not. +- String (`ErrorDomain`) constants are weak linked and automatically set to `nullptr` if not available. 
+ +## Changelog + +| Version | Changes | +|-|-| +| macOS 26, iOS 26 | Add all the Metal APIs in macOS 26, iOS 26, including support for the **Apple10** GPU family.
Add support for Metal 4 and new denoiser and temporal scalers in MetalFX.| +| macOS 15, iOS 18 | Add all the Metal APIs in macOS 15 and iOS 18. | +| macOS 14, iOS 17 | Add support for the **MetalFX** framework.
Add all the APIs in macOS 14 and iOS 17. | +| macOS 13.3, iOS 16.4 | Add all the APIs in macOS 13.3 and iOS 16.4. | +| macOS 13, iOS 16| Add all the APIs in macOS 13 and iOS 16.
New optional `NS::SharedPtr` type to assist with memory management.
New convenience function to create a `CA::MetalLayer`.
New `MTLSTR(str)` macro allows faster string creation from literals.
Fix a problem with the signature of functions that take an array of pointers as input.
Fix a problem with the signature of the `setGroups()` function in `MTL::LinkedFunctions`.| +| macOS 12, iOS 15 | Initial release. | + +## Memory Allocation Policy + +**metal-cpp** follows the object allocation policies of Cocoa, Cocoa Touch, and CoreFoundation. Understanding these rules is especially important when using metal-cpp, as C++ objects are not eligible for automatic reference counting (ARC). + +**metal-cpp** objects are reference counted. To help convey and manage object lifecycles, the following conventions are observed: + +1. *You own any object returned by methods whose name begins with* `alloc` *,* `new` *,* `copy` *,* `mutableCopy` *, or* `Create`. The method returns these objects with `retainCount` equals to `1`. +2. *You can take ownership of an object by calling its* ```retain()``` *method*. A received object is normally guaranteed to remain valid within the method it was received in. You use `retain` in two situations: (1) In the implementation of an accessor method (a setter) or to take ownership of an object; and (2) To prevent an object from being deallocated as a side-effect of some other operation. +3. *When you no longer need it, you must relinquish ownership of an object you own*. You relinquish ownership by calling its `release()` or `autorelease()` method. +4. *You must not relinquish ownership of an object you do not own*. + +When an object's `retainCount` reaches `0`, the object is immediately deallocated. It is illegal to call methods on a deallocated object and it may lead to an application crash. + +### AutoreleasePools and Objects + +Several methods that create temporary objects in **metal-cpp** add them to an `AutoreleasePool` to help manage their lifetimes. In these situations, after **metal-cpp** creates the object, it adds it to an `AutoreleasePool`, which will release its objects when you release (or drain) it. 
+ +By adding temporary objects to an AutoreleasePool, you do not need to explicitly call `release()` to deallocate them. Instead, you can rely on the `AutoreleasePool` to implicitly manage those lifetimes. + +If you create an object with a method that does not begin with `alloc`, `new`, `copy`, `mutableCopy`, or `Create`, the creating method adds the object to an autorelease pool. + +The typical scope of an `AutoreleasePool` is one frame of rendering for the main thread of the program. When the thread returns control to the RunLoop (an object responsible for receiving input and events from the windowing system), the pool is *drained*, releasing its objects. + +You can create and manage additional `AutoreleasePool`s at smaller scopes to reduce your program's working set, and you are required to do so for any additional threads your program creates. + +If an object's lifecycle needs to be extended beyond the scope of an `AutoreleasePool` instance, you can claim ownership of it by calling its `retain()` method before the pool is drained. In these cases, you are responsible for making the appropriate `release()` call on the object after you no longer need it. + +You can find a more-detailed introduction to the memory management rules here: https://developer.apple.com/library/archive/documentation/Cocoa/Conceptual/MemoryMgmt/Articles/mmRules.html, and here: https://developer.apple.com/library/archive/documentation/CoreFoundation/Conceptual/CFMemoryMgmt/Concepts/Ownership.html + +For more details about the application's RunLoop, please find its documentation here: https://developer.apple.com/documentation/foundation/nsrunloop + +### Use and debug AutoreleasePools + +When you create an autoreleased object and there is no enclosing `AutoreleasePool`, the object is leaked. + +To prevent this, you normally create an `AutoreleasePool` in your program's `main` function, and in the entry function for every thread you create. 
You may also create additional `AutoreleasePool`s to avoid growing your program's high memory watermark when you create several autoreleased objects, such as when rendering. + +Use the Environment Variable `OBJC_DEBUG_MISSING_POOLS=YES` to print a runtime warning when an autoreleased object is leaked because no enclosing `AutoreleasePool` is available for its thread. + +You can also run `leaks --autoreleasePools` on a memgraph file or a process ID (macOS only) to view a listing of your program's `AutoreleasePool`s and all objects they contain. + +### NS::SharedPtr + +The **metal-cpp** headers include an optional `NS::SharedPtr<>` (shared pointer) template that can help you manually manage memory in your apps. + +Shared pointers in **metal-cpp** are different from `std::shared_ptr<>` in that they implement specific optimizations for its memory model. For example, **metal-cpp**'s shared pointers avoid the overhead of the standard library's version by leveraging the reference counting implementation of the `NS::Object` type. + +#### Note + +The **metal-cpp** shared pointer’s destructor method always calls the `release()` method of the pointer that it wraps. + +You can create an `NS::SharedPtr<>` by calling the metal-cpp's factory method that's appropriate for your application's intent: + +* You can **transfer** ownership of a pointer to a new shared pointer instance by calling the `NS::TransferPtr()` factory function, which is the correct function for Resource Acquisition is Initialization (RAII) implementations because it doesn't increase the pointee's retain count. + +* You can **share** ownership of a pointer with another entity by calling the `NS::RetainPtr()` factory function. This function can also extend an object's lifecycle beyond an `AutoreleasePool` instance's scope because it creates a strong reference to the pointee and increases its retain count. + +Usage of `NS::SharedPtr<>` is optional. 
+ +### nullptr + +Similar to Objective-C, it is legal to call any method, including `retain()` and `release()`, on `nullptr` "objects". While calling methods on `nullptr` still does incur in function call overhead, the effective result is equivalent of a NOP. + +Conversely, do not assume that because calling a method on a pointer did not result in a crash, that the pointed-to object is valid. + +## Adding metal-cpp to a Project + +Simply include `Metal/Metal.hpp`. To ensure that the selector and class symbols are linked, add to one of your cpp files: + +```cpp +#define NS_PRIVATE_IMPLEMENTATION +#define MTL_PRIVATE_IMPLEMENTATION + +#include "Metal/Metal.hpp" +``` + +If you want to use the QuartzCore wrapper, add: + +```cpp +#define CA_PRIVATE_IMPLEMENTATION + +#include "QuartzCore/QuartzCore.hpp" +``` + +## Generating a Single Header File + +Purely optional: You can generate a single header file that contains all **metal-cpp** headers via: + +```shell +./SingleHeader/MakeSingleHeader.py Foundation/Foundation.hpp QuartzCore/QuartzCore.hpp Metal/Metal.hpp MetalFX/MetalFX.hpp +``` + +By default the generator script writes its output to `./SingleHeader/Metal.hpp`. Use the `-o` option to customize output filename. + +## Global Symbol Visibility + +metal-cpp marks all its symbols with `default` visibility. Define the macro: `METALCPP_SYMBOL_VISIBILITY_HIDDEN` to override this behavior and hide its symbols. + +## Examples + +#### Creating the device + +###### Objective-C (with automatic reference counting) + +```objc +id< MTLDevice > device = MTLCreateSystemDefaultDevice(); + +// ... +``` + +###### Objective-C + +```objc +id< MTLDevice > device = MTLCreateSystemDefaultDevice(); + +// ... + +[device release]; +``` + +###### C++ + +```cpp +MTL::Device* pDevice = MTL::CreateSystemDefaultDevice(); + +// ... 
+ +pDevice->release(); +``` + +###### C++ (using NS::SharedPtr) + +```cpp +NS::SharedPtr< MTL::Device > pDevice = NS::TransferPtr( MTL::CreateSystemDefaultDevice() ); + +// ... +``` + +#### Metal function calls map directly to C++ + +###### Objective-C (with automatic reference counting) + +```objc +MTLSamplerDescriptor* samplerDescriptor = [[MTLSamplerDescriptor alloc] init]; + +[samplerDescriptor setSAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setTAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setRAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setMagFilter: MTLSamplerMinMagFilterLinear]; +[samplerDescriptor setMinFilter: MTLSamplerMinMagFilterLinear]; +[samplerDescriptor setMipFilter: MTLSamplerMipFilterLinear]; +[samplerDescriptor setSupportArgumentBuffers: YES]; + +id< MTLSamplerState > samplerState = [device newSamplerStateWithDescriptor:samplerDescriptor]; +``` + +###### Objective-C + +```objc +MTLSamplerDescriptor* samplerDescriptor = [[MTLSamplerDescriptor alloc] init]; + +[samplerDescriptor setSAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setTAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setRAddressMode: MTLSamplerAddressModeRepeat]; +[samplerDescriptor setMagFilter: MTLSamplerMinMagFilterLinear]; +[samplerDescriptor setMinFilter: MTLSamplerMinMagFilterLinear]; +[samplerDescriptor setMipFilter: MTLSamplerMipFilterLinear]; +[samplerDescriptor setSupportArgumentBuffers: YES]; + +id< MTLSamplerState > samplerState = [device newSamplerStateWithDescriptor:samplerDescriptor]; + +[samplerDescriptor release]; + +// ... 
+ +[samplerState release]; +``` + +###### C++ + +```cpp +MTL::SamplerDescriptor* pSamplerDescriptor = MTL::SamplerDescriptor::alloc()->init(); + +pSamplerDescriptor->setSAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setTAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setRAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setMagFilter( MTL::SamplerMinMagFilterLinear ); +pSamplerDescriptor->setMinFilter( MTL::SamplerMinMagFilterLinear ); +pSamplerDescriptor->setMipFilter( MTL::SamplerMipFilterLinear ); +pSamplerDescriptor->setSupportArgumentBuffers( true ); + +MTL::SamplerState* pSamplerState = pDevice->newSamplerState( pSamplerDescriptor ); + +pSamplerDescriptor->release(); + +// ... + +pSamplerState->release(); +``` + +###### C++ (using NS::SharedPtr) + +```cpp +NS::SharedPtr< MTL::SamplerDescriptor > pSamplerDescriptor = NS::TransferPtr( MTL::SamplerDescriptor::alloc()->init() ); + +pSamplerDescriptor->setSAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setTAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setRAddressMode( MTL::SamplerAddressModeRepeat ); +pSamplerDescriptor->setMagFilter( MTL::SamplerMinMagFilterLinear ); +pSamplerDescriptor->setMinFilter( MTL::SamplerMinMagFilterLinear ); +pSamplerDescriptor->setMipFilter( MTL::SamplerMipFilterLinear ); +pSamplerDescriptor->setSupportArgumentBuffers( true ); + +NS::SharedPtr< MTL::SamplerState > pSamplerState( pDevice->newSamplerState( pSamplerDescriptor ) ); +``` + +#### A subset of bindings for Foundation classes is provided for seamless integration + +###### Objective-C (with automatic reference counting) + +```objc +NSAutoreleasePool* pool = [[NSAutoreleasePool alloc] init]; +NSString* string = [NSString stringWithCString: "Hello World" encoding: NSASCIIStringEncoding]; + +printf( "string = \"%s\"\n", [string cStringUsingEncoding: NSASCIIStringEncoding] ); +``` + +###### Objective-C + +```objc +NSAutoreleasePool* pool 
= [[NSAutoreleasePool alloc] init]; +NSString* string = [NSString stringWithCString: "Hello World" encoding: NSASCIIStringEncoding]; + +printf( "string = \"%s\"\n", [string cStringUsingEncoding: NSASCIIStringEncoding] ); + +[pool release]; +``` + +###### C++ + +```cpp +NS::AutoreleasePool* pPool = NS::AutoreleasePool::alloc()->init(); +NS::String* pString = NS::String::string( "Hello World", NS::ASCIIStringEncoding ); + +printf( "pString = \"%s\"\n", pString->cString( NS::ASCIIStringEncoding ) ); + +pPool->release(); +``` + +###### C++ (using NS::SharedPtr) + +```cpp +NS::SharedPtr< NS::AutoreleasePool > pPool = NS::TransferPtr( NS::AutoreleasePool::alloc()->init() ); +NS::String* pString = NS::String::string( "Hello World", NS::ASCIIStringEncoding ); + +printf( "pString = \"%s\"\n", pString->cString( NS::ASCIIStringEncoding ) ); +``` + +#### Containers + +Use the CoreFoundation framework to create `NS::Array` and `NS::Dictionary` instances. + +```cpp +MTL::AccelerationStructureTriangleGeometryDescriptor* pGeoDescriptor = MTL::AccelerationStructureTriangleGeometryDescriptor::alloc()->init(); +CFTypeRef descriptors[] = { ( CFTypeRef )( pGeoDescriptor ) }; +NS::Array* pGeoDescriptors = ( NS::Array* )( CFArrayCreate( kCFAllocatorDefault, descriptors, SIZEOF_ARRAY( descriptors), &kCFTypeArrayCallBacks ) ); + +// ... + +pGeoDescriptors->release(); +``` + +Containers, such as `NS::Array` and `NS::Dictionary`, retain the objects they hold and release them when the container is deallocated. + +#### Accessing the Metal Drawable + +```cpp +#import + +// ... + +CA::MetalLayer* pMetalLayer = /* layer associated with the view */; +CA::MetalDrawable* pMetalDrawable = pMetalLayer->nextDrawable(); + +// ... 
+``` diff --git a/dist/include/metal_cpp/SingleHeader/MakeSingleHeader.py b/dist/include/metal_cpp/SingleHeader/MakeSingleHeader.py new file mode 100644 index 0000000..c8d3715 --- /dev/null +++ b/dist/include/metal_cpp/SingleHeader/MakeSingleHeader.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- +# +# SingleHeader/MakeSingleHeader.py +# +# Copyright 2020-2024 Apple Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +import argparse +import datetime +import logging +import os +import re +import subprocess +import sys + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +class HeaderPrefix( object ): + __template = ( '//\n' + '// {file}\n' + '//\n' + '// {meta_data}\n' + '//\n' + '// Copyright 2020-2024 Apple Inc.\n' + '//\n' + '// Licensed under the Apache License, Version 2.0 (the "License");\n' + '// you may not use this file except in compliance with the License.\n' + '// You may obtain a copy of the License at\n' + '//\n' + '// http://www.apache.org/licenses/LICENSE-2.0\n' + '//\n' + '// Unless required by applicable law or agreed to in writing, software\n' + '// distributed under the License is distributed on an "AS IS" BASIS,\n' + '// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n' + '// See the License for the specific language governing permissions and\n' + '// limitations under the License.\n' + '//\n' + '\n' ) + + __template_commit = 'Autogenerated from commit {commit}.' + __template_date = 'Autogenerated on %B %d, %Y.' + + def __init__( self, file ): + self.__file = file + + def __str__( self ): + return self.__template.format( file = self.__file, meta_data = self.__meta_data_string() ) + + def __get_commit_hash( self ): + git_commit_hash = None + + try: + git_dir = os.path.dirname( os.path.realpath( __file__ ) ) + proc = subprocess.Popen( [ 'git', 'rev-parse', 'HEAD' ], cwd = git_dir, stdout = subprocess.PIPE, stderr = subprocess.PIPE ) + git_commit_hash = proc.stdout.read().decode( 'utf-8', 'replace' ).strip() + except: + logging.error( 'Failed to determine git commit hash!' 
) + pass + + return git_commit_hash + + def __get_commit_string( self ): + meta_data = None + git_commit_hash = self.__get_commit_hash() + + if git_commit_hash: + meta_data = self.__template_commit.format( commit = git_commit_hash ) + + return meta_data + + def __get_date_string( self ): + today = datetime.date.today() + + return today.strftime( self.__template_date ) + + def __meta_data_string( self ): + meta_data = self.__get_commit_string() + + if not meta_data: + meta_data = self.__get_date_string() + + return meta_data + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +class SingleHeader( object ): + __pragma_once = '#pragma once\n\n' + + def __init__( self ): + self.__header_paths = list() + + def __str__( self ): + return self.process() + + def append( self, header_path ): + self.__header_paths.append( header_path ) + + def process( self ): + out_header = self.__pragma_once + + self.__included_headers = set() + self.__base_path = list() + + for header_path in self.__header_paths: + out_header += self.__process_header( header_path ) + + return self.__strip_empty_lines( out_header ) + + def __read_header( self, path ): + path = os.path.realpath( path ) + + try: + f = open( path, 'r' ) + except: + raise RuntimeError( 'Failed to open file \"' + path + '\" for read!' ) + + return f.read() + + def __strip_pragma_once( self, header ): + return re.sub( '\\s*#pragma once\s*\\/\\/-*\\n', '', header ) + + def __strip_comments( self, header ): + return re.sub( '^//.*\\n', '', header, flags = re.MULTILINE ) + + def __strip_empty_lines( self, header ): + return re.sub( '\\n\\n+', '\\n\\n', header, flags = re.MULTILINE ) + + def __substitute_include_directive( self, match ): + header_path = match.group( 'HEADER_PATH' ) + + logging.info( '\tSubstituting \"' + header_path + '\"...' 
) + + return self.__process_header( os.path.join( self.__base_path[-1], header_path ) ) + + def __process_include_directives( self, header ): + return re.sub( '^\\s*#include\\s\\"(?P\\S*)\\"', self.__substitute_include_directive, header, flags = re.MULTILINE ) + + def __process_foundation_directives( self, header ): + if header.find("#include ") != -1: + logging.info( '\tSubstituting ...' ) + return header.replace("#include ", self.__process_header( os.path.join( self.__base_path[-1], "../Foundation/Foundation.hpp" ) ) ) + return header + + + def __process_header( self, header_path ): + out_header = '' + + header_path = os.path.realpath( header_path ) + + if not header_path in self.__included_headers: + logging.info( 'Processing \"' + header_path + '\"...' ) + + self.__base_path.append( os.path.dirname( header_path ) ) + self.__included_headers.add( header_path ) + + out_header = self.__read_header( header_path ) + out_header = self.__strip_pragma_once( out_header ) + out_header = self.__strip_comments( out_header ) + out_header = self.__process_include_directives( out_header ) + out_header = self.__process_foundation_directives( out_header ) + + self.__base_path.pop() + else: + logging.info( '\tSkipping \"' + header_path + '\"...' ) + + return out_header + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +def create_argument_parser(): + parser = argparse.ArgumentParser() + base_path = os.path.dirname( os.path.realpath( __file__ ) ) + output_path = os.path.join( base_path, 'Metal.hpp' ) + + parser.add_argument( '-o', '--output', dest = 'output_path', metavar = 'PATH', default = output_path, help = 'Output path for the single header file.' ) + parser.add_argument( '-v', '--verbose', action = 'store_true', help = 'Show verbose output.' ) + parser.add_argument( dest = 'header_paths', metavar = 'HEADER_FILE', nargs='+', help = 'Input header file.' 
) + + return parser + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +def parse_arguments(): + parser = create_argument_parser() + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel( logging.INFO ) + else: + logging.getLogger().setLevel( logging.ERROR ) + + return args + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +def make_header( args ): + prefix = HeaderPrefix( os.path.basename( args.output_path ) ) + header = SingleHeader() + + for header_path in args.header_paths: + header.append( header_path ) + + return str( prefix ) + str( header ) + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +def make_dir( path ): + try: + if not os.path.exists( path ): + os.makedirs( path ) + except os.error: + pass + except: + raise + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +def write_header( args, content ): + path = os.path.realpath( args.output_path ) + + logging.info( 'Writing \"' + path + '\"...' ) + + make_dir( os.path.dirname( path ) ) + + try: + f = open( path, 'w' ) + except: + raise RuntimeError( 'Failed to open file \"' + path + '\" for write!' 
) + + f.write( content ) + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- + +if __name__ == '__main__': + result = -1 + + try: + if sys.getdefaultencoding().lower() == 'ascii': + reload( sys ) + sys.setdefaultencoding( 'utf-8' ) + + args = parse_arguments() + header = make_header( args ) + + write_header( args, header ) + + result = 0 + + except ( KeyboardInterrupt, SystemExit ): + pass + except: + raise + + sys.exit( result ) + +#-------------------------------------------------------------------------------------------------------------------------------------------------------------- diff --git a/dist/include/mlx/3rdparty/pocketfft.h b/dist/include/mlx/3rdparty/pocketfft.h new file mode 100644 index 0000000..03a4589 --- /dev/null +++ b/dist/include/mlx/3rdparty/pocketfft.h @@ -0,0 +1,3581 @@ +/* +This file is part of pocketfft. + +Copyright (C) 2010-2022 Max-Planck-Society +Copyright (C) 2019-2020 Peter Bell + +For the odd-sized DCT-IV transforms: + Copyright (C) 2003, 2007-14 Matteo Frigo + Copyright (C) 2003, 2007-14 Massachusetts Institute of Technology + +Authors: Martin Reinecke, Peter Bell + +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. +* Neither the name of the copyright holder nor the names of its contributors may + be used to endorse or promote products derived from this software without + specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#ifndef POCKETFFT_HDRONLY_H +#define POCKETFFT_HDRONLY_H + +#ifndef __cplusplus +#error This file is C++ and requires a C++ compiler. +#endif + +#if !(__cplusplus >= 201103L || _MSVC_LANG+0L >= 201103L) +#error This file requires at least C++11 support. 
+#endif + +#ifndef POCKETFFT_CACHE_SIZE +#define POCKETFFT_CACHE_SIZE 0 +#endif + +#include +#include +#include +#include +#include +#include +#include +#if POCKETFFT_CACHE_SIZE!=0 +#include +#include +#endif + +#ifndef POCKETFFT_NO_MULTITHREADING +#include +#include +#include +#include +#include +#include +#include + +#ifdef POCKETFFT_PTHREADS +# include +#endif +#endif + +#if defined(__GNUC__) +#define POCKETFFT_NOINLINE __attribute__((noinline)) +#define POCKETFFT_RESTRICT __restrict__ +#elif defined(_MSC_VER) +#define POCKETFFT_NOINLINE __declspec(noinline) +#define POCKETFFT_RESTRICT __restrict +#else +#define POCKETFFT_NOINLINE +#define POCKETFFT_RESTRICT +#endif + +namespace pocketfft { + +namespace detail { +using std::size_t; +using std::ptrdiff_t; + +// Always use std:: for functions +template T cos(T) = delete; +template T sin(T) = delete; +template T sqrt(T) = delete; + +using shape_t = std::vector; +using stride_t = std::vector; + +constexpr bool FORWARD = true, + BACKWARD = false; + +// only enable vector support for gcc>=5.0 and clang>=5.0 +#ifndef POCKETFFT_NO_VECTORS +#define POCKETFFT_NO_VECTORS +#if defined(__INTEL_COMPILER) +// do nothing. This is necessary because this compiler also sets __GNUC__. 
+#elif defined(__clang__) +// AppleClang has their own version numbering +#ifdef __apple_build_version__ +# if (__clang_major__ > 9) || (__clang_major__ == 9 && __clang_minor__ >= 1) +# undef POCKETFFT_NO_VECTORS +# endif +#elif __clang_major__ >= 5 +# undef POCKETFFT_NO_VECTORS +#endif +#elif defined(__GNUC__) +#if __GNUC__>=5 +#undef POCKETFFT_NO_VECTORS +#endif +#endif +#endif + +template struct VLEN { static constexpr size_t val=1; }; + +#ifndef POCKETFFT_NO_VECTORS +#if (defined(__AVX512F__)) +template<> struct VLEN { static constexpr size_t val=16; }; +template<> struct VLEN { static constexpr size_t val=8; }; +#elif (defined(__AVX__)) +template<> struct VLEN { static constexpr size_t val=8; }; +template<> struct VLEN { static constexpr size_t val=4; }; +#elif (defined(__SSE2__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__VSX__)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#elif (defined(__ARM_NEON__) || defined(__ARM_NEON)) +template<> struct VLEN { static constexpr size_t val=4; }; +template<> struct VLEN { static constexpr size_t val=2; }; +#else +#define POCKETFFT_NO_VECTORS +#endif +#endif + +// the __MINGW32__ part in the conditional below works around the problem that +// the standard C++ library on Windows does not provide aligned_alloc() even +// though the MinGW compiler and MSVC may advertise C++17 compliance. 
+#if (__cplusplus >= 201703L) && (!defined(__MINGW32__)) && (!defined(_MSC_VER)) +inline void *aligned_alloc(size_t align, size_t size) + { + // aligned_alloc() requires that the requested size is a multiple of "align" + void *ptr = ::aligned_alloc(align,(size+align-1)&(~(align-1))); + if (!ptr) throw std::bad_alloc(); + return ptr; + } +inline void aligned_dealloc(void *ptr) + { free(ptr); } +#else // portable emulation +inline void *aligned_alloc(size_t align, size_t size) + { + align = std::max(align, alignof(max_align_t)); + void *ptr = malloc(size+align); + if (!ptr) throw std::bad_alloc(); + void *res = reinterpret_cast + ((reinterpret_cast(ptr) & ~(uintptr_t(align-1))) + uintptr_t(align)); + (reinterpret_cast(res))[-1] = ptr; + return res; + } +inline void aligned_dealloc(void *ptr) + { if (ptr) free((reinterpret_cast(ptr))[-1]); } +#endif + +template class arr + { + private: + T *p; + size_t sz; + +#if defined(POCKETFFT_NO_VECTORS) + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *res = malloc(num*sizeof(T)); + if (!res) throw std::bad_alloc(); + return reinterpret_cast(res); + } + static void dealloc(T *ptr) + { free(ptr); } +#else + static T *ralloc(size_t num) + { + if (num==0) return nullptr; + void *ptr = aligned_alloc(64, num*sizeof(T)); + return static_cast(ptr); + } + static void dealloc(T *ptr) + { aligned_dealloc(ptr); } +#endif + + public: + arr() : p(0), sz(0) {} + arr(size_t n) : p(ralloc(n)), sz(n) {} + arr(arr &&other) + : p(other.p), sz(other.sz) + { other.p=nullptr; other.sz=0; } + ~arr() { dealloc(p); } + + void resize(size_t n) + { + if (n==sz) return; + dealloc(p); + p = ralloc(n); + sz = n; + } + + T &operator[](size_t idx) { return p[idx]; } + const T &operator[](size_t idx) const { return p[idx]; } + + T *data() { return p; } + const T *data() const { return p; } + + size_t size() const { return sz; } + }; + +template struct cmplx { + T r, i; + cmplx() {} + cmplx(T r_, T i_) : r(r_), i(i_) {} + void Set(T r_, T 
i_) { r=r_; i=i_; } + void Set(T r_) { r=r_; i=T(0); } + cmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator*= (T2 other) + { r*=other; i*=other; return *this; } + templatecmplx &operator*= (const cmplx &other) + { + T tmp = r*other.r - i*other.i; + i = r*other.i + i*other.r; + r = tmp; + return *this; + } + templatecmplx &operator+= (const cmplx &other) + { r+=other.r; i+=other.i; return *this; } + templatecmplx &operator-= (const cmplx &other) + { r-=other.r; i-=other.i; return *this; } + template auto operator* (const T2 &other) const + -> cmplx + { return {r*other, i*other}; } + template auto operator+ (const cmplx &other) const + -> cmplx + { return {r+other.r, i+other.i}; } + template auto operator- (const cmplx &other) const + -> cmplx + { return {r-other.r, i-other.i}; } + template auto operator* (const cmplx &other) const + -> cmplx + { return {r*other.r-i*other.i, r*other.i + i*other.r}; } + template auto special_mul (const cmplx &other) const + -> cmplx + { + using Tres = cmplx; + return fwd ? Tres(r*other.r+i*other.i, i*other.r-r*other.i) + : Tres(r*other.r-i*other.i, r*other.i+i*other.r); + } +}; +template inline void PM(T &a, T &b, T c, T d) + { a=c+d; b=c-d; } +template inline void PMINPLACE(T &a, T &b) + { T t = a; a+=b; b=t-b; } +template inline void MPINPLACE(T &a, T &b) + { T t = a; a-=b; b=t+b; } +template cmplx conj(const cmplx &a) + { return {a.r, -a.i}; } +template void special_mul (const cmplx &v1, const cmplx &v2, cmplx &res) + { + res = fwd ? cmplx(v1.r*v2.r+v1.i*v2.i, v1.i*v2.r-v1.r*v2.i) + : cmplx(v1.r*v2.r-v1.i*v2.i, v1.r*v2.i+v1.i*v2.r); + } + +template void ROT90(cmplx &a) + { auto tmp_=a.r; a.r=-a.i; a.i=tmp_; } +template void ROTX90(cmplx &a) + { auto tmp_= fwd ? -a.r : a.r; a.r = fwd ? 
a.i : -a.i; a.i=tmp_; } + +// +// twiddle factor section +// +template class sincos_2pibyn + { + private: + using Thigh = typename std::conditional<(sizeof(T)>sizeof(double)), T, double>::type; + size_t N, mask, shift; + arr> v1, v2; + + static cmplx calc(size_t x, size_t n, Thigh ang) + { + x<<=3; + if (x<4*n) // first half + { + if (x<2*n) // first quadrant + { + if (x(std::cos(Thigh(x)*ang), std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), std::cos(Thigh(2*n-x)*ang)); + } + else // second quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), std::sin(Thigh(2*n-x)*ang)); + } + } + else + { + x=8*n-x; + if (x<2*n) // third quadrant + { + if (x(std::cos(Thigh(x)*ang), -std::sin(Thigh(x)*ang)); + return cmplx(std::sin(Thigh(2*n-x)*ang), -std::cos(Thigh(2*n-x)*ang)); + } + else // fourth quadrant + { + x-=2*n; + if (x(-std::sin(Thigh(x)*ang), -std::cos(Thigh(x)*ang)); + return cmplx(-std::cos(Thigh(2*n-x)*ang), -std::sin(Thigh(2*n-x)*ang)); + } + } + } + + public: + POCKETFFT_NOINLINE sincos_2pibyn(size_t n) + : N(n) + { + constexpr auto pi = 3.141592653589793238462643383279502884197L; + Thigh ang = Thigh(0.25L*pi/n); + size_t nval = (n+2)/2; + shift = 1; + while((size_t(1)< operator[](size_t idx) const + { + if (2*idx<=N) + { + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), T(x1.r*x2.i+x1.i*x2.r)); + } + idx = N-idx; + auto x1=v1[idx&mask], x2=v2[idx>>shift]; + return cmplx(T(x1.r*x2.r-x1.i*x2.i), -T(x1.r*x2.i+x1.i*x2.r)); + } + }; + +struct util // hack to avoid duplicate symbols + { + static POCKETFFT_NOINLINE size_t largest_prime_factor (size_t n) + { + size_t res=1; + while ((n&1)==0) + { res=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { res=x; n/=x; } + if (n>1) res=n; + return res; + } + + static POCKETFFT_NOINLINE double cost_guess (size_t n) + { + constexpr double lfp=1.1; // penalty for non-hardcoded larger factors 
+ size_t ni=n; + double result=0.; + while ((n&1)==0) + { result+=2; n>>=1; } + for (size_t x=3; x*x<=n; x+=2) + while ((n%x)==0) + { + result+= (x<=5) ? double(x) : lfp*double(x); // penalize larger prime factors + n/=x; + } + if (n>1) result+=(n<=5) ? double(n) : lfp*double(n); + return result*double(ni); + } + + /* returns the smallest composite of 2, 3, 5, 7 and 11 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_cmplx(size_t n) + { + if (n<=12) return n; + + size_t bestfac=2*n; + for (size_t f11=1; f11n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + /* returns the smallest composite of 2, 3, 5 which is >= n */ + static POCKETFFT_NOINLINE size_t good_size_real(size_t n) + { + if (n<=6) return n; + + size_t bestfac=2*n; + for (size_t f5=1; f5n) + { + if (x>=1; + } + else + return n; + } + } + return bestfac; + } + + static size_t prod(const shape_t &shape) + { + size_t res=1; + for (auto sz: shape) + res*=sz; + return res; + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace) + { + auto ndim = shape.size(); + if (ndim<1) throw std::runtime_error("ndim must be >= 1"); + if ((stride_in.size()!=ndim) || (stride_out.size()!=ndim)) + throw std::runtime_error("stride dimension mismatch"); + if (inplace && (stride_in!=stride_out)) + throw std::runtime_error("stride mismatch"); + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, bool inplace, + const shape_t &axes) + { + sanity_check(shape, stride_in, stride_out, inplace); + auto ndim = shape.size(); + shape_t tmp(ndim,0); + for (auto ax : axes) + { + if (ax>=ndim) throw std::invalid_argument("bad axis number"); + if (++tmp[ax]>1) throw std::invalid_argument("axis specified repeatedly"); + } + } + + static POCKETFFT_NOINLINE void sanity_check(const shape_t &shape, + const stride_t &stride_in, const stride_t 
&stride_out, bool inplace, + size_t axis) + { + sanity_check(shape, stride_in, stride_out, inplace); + if (axis>=shape.size()) throw std::invalid_argument("bad axis number"); + } + +#ifdef POCKETFFT_NO_MULTITHREADING + static size_t thread_count (size_t /*nthreads*/, const shape_t &/*shape*/, + size_t /*axis*/, size_t /*vlen*/) + { return 1; } +#else + static size_t thread_count (size_t nthreads, const shape_t &shape, + size_t axis, size_t vlen) + { + if (nthreads==1) return 1; + size_t size = prod(shape); + size_t parallel = size / (shape[axis] * vlen); + if (shape[axis] < 1000) + parallel /= 4; + size_t max_threads = nthreads == 0 ? + std::thread::hardware_concurrency() : nthreads; + return std::max(size_t(1), std::min(parallel, max_threads)); + } +#endif + }; + +namespace threading { + +#ifdef POCKETFFT_NO_MULTITHREADING + +constexpr inline size_t thread_id() { return 0; } +constexpr inline size_t num_threads() { return 1; } + +template +void thread_map(size_t /* nthreads */, Func f) + { f(); } + +#else + +inline size_t &thread_id() + { + static thread_local size_t thread_id_=0; + return thread_id_; + } +inline size_t &num_threads() + { + static thread_local size_t num_threads_=1; + return num_threads_; + } +static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); + +class latch + { + std::atomic num_left_; + std::mutex mut_; + std::condition_variable completed_; + using lock_t = std::unique_lock; + + public: + latch(size_t n): num_left_(n) {} + + void count_down() + { + lock_t lock(mut_); + if (--num_left_) + return; + completed_.notify_all(); + } + + void wait() + { + lock_t lock(mut_); + completed_.wait(lock, [this]{ return is_ready(); }); + } + bool is_ready() { return num_left_ == 0; } + }; + +template class concurrent_queue + { + std::queue q_; + std::mutex mut_; + std::atomic size_; + using lock_t = std::lock_guard; + + public: + + void push(T val) + { + lock_t lock(mut_); + ++size_; + q_.push(std::move(val)); + } + + bool 
try_pop(T &val) + { + if (size_ == 0) return false; + lock_t lock(mut_); + // Queue might have been emptied while we acquired the lock + if (q_.empty()) return false; + + val = std::move(q_.front()); + --size_; + q_.pop(); + return true; + } + + bool empty() const { return size_==0; } + }; + +// C++ allocator with support for over-aligned types +template struct aligned_allocator + { + using value_type = T; + template + aligned_allocator(const aligned_allocator&) {} + aligned_allocator() = default; + + T *allocate(size_t n) + { + void* mem = aligned_alloc(alignof(T), n*sizeof(T)); + return static_cast(mem); + } + + void deallocate(T *p, size_t /*n*/) + { aligned_dealloc(p); } + }; + +class thread_pool + { + // A reasonable guess, probably close enough for most hardware + static constexpr size_t cache_line_size = 64; + struct alignas(cache_line_size) worker + { + std::thread thread; + std::condition_variable work_ready; + std::mutex mut; + std::atomic_flag busy_flag = ATOMIC_FLAG_INIT; + std::function work; + + void worker_main( + std::atomic &shutdown_flag, + std::atomic &unscheduled_tasks, + concurrent_queue> &overflow_work) + { + using lock_t = std::unique_lock; + bool expect_work = true; + while (!shutdown_flag || expect_work) + { + std::function local_work; + if (expect_work || unscheduled_tasks == 0) + { + lock_t lock(mut); + // Wait until there is work to be executed + work_ready.wait(lock, [&]{ return (work || shutdown_flag); }); + local_work.swap(work); + expect_work = false; + } + + bool marked_busy = false; + if (local_work) + { + marked_busy = true; + local_work(); + } + + if (!overflow_work.empty()) + { + if (!marked_busy && busy_flag.test_and_set()) + { + expect_work = true; + continue; + } + marked_busy = true; + + while (overflow_work.try_pop(local_work)) + { + --unscheduled_tasks; + local_work(); + } + } + + if (marked_busy) busy_flag.clear(); + } + } + }; + + concurrent_queue> overflow_work_; + std::mutex mut_; + std::vector> workers_; + std::atomic 
shutdown_; + std::atomic unscheduled_tasks_; + using lock_t = std::lock_guard; + + void create_threads() + { + lock_t lock(mut_); + size_t nthreads=workers_.size(); + for (size_t i=0; ibusy_flag.clear(); + worker->work = nullptr; + worker->thread = std::thread([worker, this] + { + worker->worker_main(shutdown_, unscheduled_tasks_, overflow_work_); + }); + } + catch (...) + { + shutdown_locked(); + throw; + } + } + } + + void shutdown_locked() + { + shutdown_ = true; + for (auto &worker : workers_) + worker.work_ready.notify_all(); + + for (auto &worker : workers_) + if (worker.thread.joinable()) + worker.thread.join(); + } + + public: + explicit thread_pool(size_t nthreads): + workers_(nthreads) + { create_threads(); } + + thread_pool(): thread_pool(max_threads) {} + + ~thread_pool() { shutdown(); } + + void submit(std::function work) + { + lock_t lock(mut_); + if (shutdown_) + throw std::runtime_error("Work item submitted after shutdown"); + + ++unscheduled_tasks_; + + // First check for any idle workers and wake those + for (auto &worker : workers_) + if (!worker.busy_flag.test_and_set()) + { + --unscheduled_tasks_; + { + lock_t lock(worker.mut); + worker.work = std::move(work); + } + worker.work_ready.notify_one(); + return; + } + + // If no workers were idle, push onto the overflow queue for later + overflow_work_.push(std::move(work)); + } + + void shutdown() + { + lock_t lock(mut_); + shutdown_locked(); + } + + void restart() + { + shutdown_ = false; + create_threads(); + } + }; + +inline thread_pool & get_pool() + { + static thread_pool pool; +#ifdef POCKETFFT_PTHREADS + static std::once_flag f; + std::call_once(f, + []{ + pthread_atfork( + +[]{ get_pool().shutdown(); }, // prepare + +[]{ get_pool().restart(); }, // parent + +[]{ get_pool().restart(); } // child + ); + }); +#endif + + return pool; + } + +/** Map a function f over nthreads */ +template +void thread_map(size_t nthreads, Func f) + { + if (nthreads == 0) + nthreads = max_threads; + + if 
(nthreads == 1) + { f(); return; } + + auto & pool = get_pool(); + latch counter(nthreads); + std::exception_ptr ex; + std::mutex ex_mut; + for (size_t i=0; i lock(ex_mut); + ex = std::current_exception(); + } + counter.count_down(); + }); + } + counter.wait(); + if (ex) + std::rethrow_exception(ex); + } + +#endif + +} + +// +// complex FFTPACK transforms +// + +template class cfftp + { + private: + struct fctdata + { + size_t fct; + cmplx *tw, *tws; + }; + + size_t length; + arr> mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +template void pass2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(CC(i,0,k)-CC(i,1,k),WA(0,i),CH(i,k,1)); + } + } + } + +#define POCKETFFT_PREP3(idx) \ + T t0 = CC(idx,0,k), t1, t2; \ + PM (t1,t2,CC(idx,1,k),CC(idx,2,k)); \ + CH(idx,k,0)=t0+t1; +#define POCKETFFT_PARTSTEP3a(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb) ;\ + } +#define POCKETFFT_PARTSTEP3b(u1,u2,twr,twi) \ + { \ + T ca=t0+t1*twr; \ + T cb{-t2.i*twi, t2.r*twi}; \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass3 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r=-0.5, + tw1i= (fwd ? 
-1: 1) * T0(0.8660254037844386467637231707529362L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void pass4 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + else + for (size_t k=0; k(t4); + PM(CH(0,k,0),CH(0,k,2),t2,t3); + PM(CH(0,k,1),CH(0,k,3),t1,t4); + } + for (size_t i=1; i(t4); + CH(i,k,0) = t2+t3; + special_mul(t1+t4,WA(0,i),CH(i,k,1)); + special_mul(t2-t3,WA(1,i),CH(i,k,2)); + special_mul(t1-t4,WA(2,i),CH(i,k,3)); + } + } + } + +#define POCKETFFT_PREP5(idx) \ + T t0 = CC(idx,0,k), t1, t2, t3, t4; \ + PM (t1,t4,CC(idx,1,k),CC(idx,4,k)); \ + PM (t2,t3,CC(idx,2,k),CC(idx,3,k)); \ + CH(idx,k,0).r=t0.r+t1.r+t2.r; \ + CH(idx,k,0).i=t0.i+t1.i+t2.i; + +#define POCKETFFT_PARTSTEP5a(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + PM(CH(0,k,u1),CH(0,k,u2),ca,cb); \ + } + +#define POCKETFFT_PARTSTEP5b(u1,u2,twar,twbr,twai,twbi) \ + { \ + T ca,cb,da,db; \ + ca.r=t0.r+twar*t1.r+twbr*t2.r; \ + ca.i=t0.i+twar*t1.i+twbr*t2.i; \ + cb.i=twai*t4.r twbi*t3.r; \ + cb.r=-(twai*t4.i twbi*t3.i); \ + special_mul(ca+cb,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(ca-cb,WA(u2-1,i),CH(i,k,u2)); \ + } +template void pass5 (size_t ido, size_t l1, + 
const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.3090169943749474241022934171828191L), + tw1i= (fwd ? -1: 1) * T0(0.9510565162951535721164393333793821L), + tw2r= T0(-0.8090169943749474241022934171828191L), + tw2i= (fwd ? -1: 1) * T0(0.5877852522924731291687059546390728L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass7(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.6234898018587335305250048840042398L), + tw1i= (fwd ? -1 : 1) * T0(0.7818314824680298087084445266740578L), + tw2r= T0(-0.2225209339563144042889025644967948L), + tw2i= (fwd ? -1 : 1) * T0(0.9749279121818236070181316829939312L), + tw3r= T0(-0.9009688679024191262361023195074451L), + tw3i= (fwd ? 
-1 : 1) * T0(0.433883739117558120475768332848359L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+7*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void ROTX45(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.r+a.i); a.i=hsqt2*(a.i-tmp_); } + else + { auto tmp_=a.r; a.r=hsqt2*(a.r-a.i); a.i=hsqt2*(a.i+tmp_); } + } +template void ROTX135(T &a) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + if (fwd) + { auto tmp_=a.r; a.r=hsqt2*(a.i-a.r); a.i=hsqt2*(-tmp_-a.i); } + else + { auto tmp_=a.r; a.r=hsqt2*(-a.r-a.i); a.i=hsqt2*(tmp_-a.i); } + } + +template void pass8 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+8*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + else + for (size_t k=0; k(a3); + + ROTX90(a7); + PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + + PM(a0,a4,CC(0,0,k),CC(0,4,k)); + PM(a2,a6,CC(0,2,k),CC(0,6,k)); + PM(CH(0,k,0),CH(0,k,4),a0+a2,a1); + PM(CH(0,k,2),CH(0,k,6),a0-a2,a3); + ROTX90(a6); + PM(CH(0,k,1),CH(0,k,5),a4+a6,a5); + PM(CH(0,k,3),CH(0,k,7),a4-a6,a7); + } + for (size_t i=1; i(a7); + PMINPLACE(a1,a3); + ROTX90(a3); 
+ PMINPLACE(a5,a7); + ROTX45(a5); + ROTX135(a7); + PM(a0,a4,CC(i,0,k),CC(i,4,k)); + PM(a2,a6,CC(i,2,k),CC(i,6,k)); + PMINPLACE(a0,a2); + CH(i,k,0) = a0+a1; + special_mul(a0-a1,WA(3,i),CH(i,k,4)); + special_mul(a2+a3,WA(1,i),CH(i,k,2)); + special_mul(a2-a3,WA(5,i),CH(i,k,6)); + ROTX90(a6); + PMINPLACE(a4,a6); + special_mul(a4+a5,WA(0,i),CH(i,k,1)); + special_mul(a4-a5,WA(4,i),CH(i,k,5)); + special_mul(a6+a7,WA(2,i),CH(i,k,3)); + special_mul(a6-a7,WA(6,i),CH(i,k,7)); + } + } + } + + +#define POCKETFFT_PREP11(idx) \ + T t1 = CC(idx,0,k), t2, t3, t4, t5, t6, t7, t8, t9, t10, t11; \ + PM (t2,t11,CC(idx,1,k),CC(idx,10,k)); \ + PM (t3,t10,CC(idx,2,k),CC(idx, 9,k)); \ + PM (t4,t9 ,CC(idx,3,k),CC(idx, 8,k)); \ + PM (t5,t8 ,CC(idx,4,k),CC(idx, 7,k)); \ + PM (t6,t7 ,CC(idx,5,k),CC(idx, 6,k)); \ + CH(idx,k,0).r=t1.r+t2.r+t3.r+t4.r+t5.r+t6.r; \ + CH(idx,k,0).i=t1.i+t2.i+t3.i+t4.i+t5.i+t6.i; + +#define POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,out1,out2) \ + { \ + T ca = t1 + t2*x1 + t3*x2 + t4*x3 + t5*x4 +t6*x5, \ + cb; \ + cb.i=y1*t11.r y2*t10.r y3*t9.r y4*t8.r y5*t7.r; \ + cb.r=-(y1*t11.i y2*t10.i y3*t9.i y4*t8.i y5*t7.i ); \ + PM(out1,out2,ca,cb); \ + } +#define POCKETFFT_PARTSTEP11a(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,CH(0,k,u1),CH(0,k,u2)) +#define POCKETFFT_PARTSTEP11(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5) \ + { \ + T da,db; \ + POCKETFFT_PARTSTEP11a0(u1,u2,x1,x2,x3,x4,x5,y1,y2,y3,y4,y5,da,db) \ + special_mul(da,WA(u1-1,i),CH(i,k,u1)); \ + special_mul(db,WA(u2-1,i),CH(i,k,u2)); \ + } + +template void pass11 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tw1r= T0(0.8412535328311811688618116489193677L), + tw1i= (fwd ? -1 : 1) * T0(0.5406408174555975821076359543186917L), + tw2r= T0(0.4154150130018864255292741492296232L), + tw2i= (fwd ? 
-1 : 1) * T0(0.9096319953545183714117153830790285L), + tw3r= T0(-0.1423148382732851404437926686163697L), + tw3i= (fwd ? -1 : 1) * T0(0.9898214418809327323760920377767188L), + tw4r= T0(-0.6548607339452850640569250724662936L), + tw4i= (fwd ? -1 : 1) * T0(0.7557495743542582837740358439723444L), + tw5r= T0(-0.9594929736144973898903680570663277L), + tw5i= (fwd ? -1 : 1) * T0(0.2817325568414296977114179153466169L); + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+11*c)]; }; + auto WA = [wa, ido](size_t x, size_t i) + { return wa[i-1+x*(ido-1)]; }; + + if (ido==1) + for (size_t k=0; k void passg (size_t ido, size_t ip, + size_t l1, T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const cmplx * POCKETFFT_RESTRICT wa, + const cmplx * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph = (ip+1)/2; + size_t idl1 = ido*l1; + + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CX = [cc, ido, l1](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto CX2 = [cc, idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch, idl1](size_t a, size_t b) -> const T& + { return ch[a+idl1*b]; }; + + arr> wal(ip); + wal[0] = cmplx(1., 0.); + for (size_t i=1; i(csarr[i].r,fwd ? 
-csarr[i].i : csarr[i].i); + + for (size_t k=0; kip) iwal-=ip; + cmplx xwal=wal[iwal]; + iwal+=l; if (iwal>ip) iwal-=ip; + cmplx xwal2=wal[iwal]; + for (size_t ik=0; ikip) iwal-=ip; + cmplx xwal=wal[iwal]; + for (size_t ik=0; ik(x1,wa[idij],CX(i,k,j)); + idij=(jc-1)*(ido-1)+i-1; + special_mul(x2,wa[idij],CX(i,k,jc)); + } + } + } + } + +template void pass_all(T c[], T0 fct) const + { + if (length==1) { c[0]*=fct; return; } + size_t l1=1; + arr ch(length); + T *p1=c, *p2=ch.data(); + + for(size_t k1=0; k1 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==8) + pass8(ido, l1, p1, p2, fact[k1].tw); + else if(ip==2) + pass2(ido, l1, p1, p2, fact[k1].tw); + else if(ip==3) + pass3 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==5) + pass5 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==7) + pass7 (ido, l1, p1, p2, fact[k1].tw); + else if(ip==11) + pass11 (ido, l1, p1, p2, fact[k1].tw); + else + { + passg(ido, ip, l1, p1, p2, fact[k1].tw, fact[k1].tws); + std::swap(p1,p2); + } + std::swap(p1,p2); + l1=l2; + } + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool fwd) const + { fwd ? 
pass_all(c, fct) : pass_all(c, fct); } + + private: + POCKETFFT_NOINLINE void factorize() + { + size_t len=length; + while ((len&7)==0) + { add_factor(8); len>>=3; } + while ((len&3)==0) + { add_factor(4); len>>=2; } + if ((len&1)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsize=0, l1=1; + for (size_t k=0; k11) + twsize+=ip; + l1*=ip; + } + return twsize; + } + + void comp_twiddle() + { + sincos_2pibyn twiddle(length); + size_t l1=1; + size_t memofs=0; + for (size_t k=0; k11) + { + fact[k].tws=mem.data()+memofs; + memofs+=ip; + for (size_t j=0; j class rfftp + { + private: + struct fctdata + { + size_t fct; + T0 *tw, *tws; + }; + + size_t length; + arr mem; + std::vector fact; + + void add_factor(size_t factor) + { fact.push_back({factor, nullptr, nullptr}); } + +/* (a+ib) = conj(c+id) * (e+if) */ +template inline void MULPM + (T1 &a, T1 &b, T2 c, T2 d, T3 e, T3 f) const + { a=c*e+d*f; b=c*f-d*e; } + +template void radf2 (size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+2*c)]; }; + + for (size_t k=0; k void radf3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> 
const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+3*c)]; }; + + for (size_t k=0; k void radf4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 hsqt2=T0(0.707106781186547524400844362104849L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+4*c)]; }; + + for (size_t k=0; k void radf5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto CH = [ch,ido](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+5*c)]; }; + + for (size_t k=0; k void radfg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1] (size_t a, size_t b, size_t c) -> T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1] (size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1] (size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + 
+ if (ido>1) + { + for (size_t j=1, jc=ip-1; j=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if (iang>=ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ik=ip) iang-=ip; + T0 ar=csarr[2*iang], ai=csarr[2*iang+1]; + for (size_t ik=0; ik void radb2(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+2*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb3(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 taur=-0.5, taui=T0(0.8660254037844386467637231707529362L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+3*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb4(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+4*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return 
ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radb5(size_t ido, size_t l1, + const T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa) const + { + constexpr T0 tr11= T0(0.3090169943749474241022934171828191L), + ti11= T0(0.9510565162951535721164393333793821L), + tr12= T0(-0.8090169943749474241022934171828191L), + ti12= T0(0.5877852522924731291687059546390728L); + + auto WA = [wa,ido](size_t x, size_t i) { return wa[i+x*(ido-1)]; }; + auto CC = [cc,ido](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+5*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + + for (size_t k=0; k void radbg(size_t ido, size_t ip, size_t l1, + T * POCKETFFT_RESTRICT cc, T * POCKETFFT_RESTRICT ch, + const T0 * POCKETFFT_RESTRICT wa, const T0 * POCKETFFT_RESTRICT csarr) const + { + const size_t cdim=ip; + size_t ipph=(ip+1)/ 2; + size_t idl1 = ido*l1; + + auto CC = [cc,ido,cdim](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+cdim*c)]; }; + auto CH = [ch,ido,l1](size_t a, size_t b, size_t c) -> T& + { return ch[a+ido*(b+l1*c)]; }; + auto C1 = [cc,ido,l1](size_t a, size_t b, size_t c) -> const T& + { return cc[a+ido*(b+l1*c)]; }; + auto C2 = [cc,idl1](size_t a, size_t b) -> T& + { return cc[a+idl1*b]; }; + auto CH2 = [ch,idl1](size_t a, size_t b) -> T& + { return ch[a+idl1*b]; }; + + for (size_t k=0; kip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar3=csarr[2*iang], ai3=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar4=csarr[2*iang], ai4=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 ar1=csarr[2*iang], ai1=csarr[2*iang+1]; + iang+=l; if(iang>ip) iang-=ip; + T0 ar2=csarr[2*iang], ai2=csarr[2*iang+1]; + for (size_t ik=0; ikip) iang-=ip; + T0 war=csarr[2*iang], wai=csarr[2*iang+1]; + for (size_t ik=0; ik void copy_and_norm(T 
*c, T *p1, T0 fct) const + { + if (p1!=c) + { + if (fct!=1.) + for (size_t i=0; i void exec(T c[], T0 fct, bool r2hc) const + { + if (length==1) { c[0]*=fct; return; } + size_t nf=fact.size(); + arr ch(length); + T *p1=c, *p2=ch.data(); + + if (r2hc) + for(size_t k1=0, l1=length; k1>=2; } + if ((len%2)==0) + { + len>>=1; + // factor 2 should be at the front of the factor list + add_factor(2); + std::swap(fact[0].fct, fact.back().fct); + } + for (size_t divisor=3; divisor*divisor<=len; divisor+=2) + while ((len%divisor)==0) + { + add_factor(divisor); + len/=divisor; + } + if (len>1) add_factor(len); + } + + size_t twsize() const + { + size_t twsz=0, l1=1; + for (size_t k=0; k5) twsz+=2*ip; + l1*=ip; + } + return twsz; + } + + void comp_twiddle() + { + sincos_2pibyn twid(length); + size_t l1=1; + T0 *ptr=mem.data(); + for (size_t k=0; k5) // special factors required by *g functions + { + fact[k].tws=ptr; ptr+=2*ip; + fact[k].tws[0] = 1.; + fact[k].tws[1] = 0.; + for (size_t i=2, ic=2*ip-2; i<=ic; i+=2, ic-=2) + { + fact[k].tws[i ] = twid[i/2*(length/ip)].r; + fact[k].tws[i+1] = twid[i/2*(length/ip)].i; + fact[k].tws[ic] = twid[i/2*(length/ip)].r; + fact[k].tws[ic+1] = -twid[i/2*(length/ip)].i; + } + } + l1*=ip; + } + } + + public: + POCKETFFT_NOINLINE rfftp(size_t length_) + : length(length_) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + if (length==1) return; + factorize(); + mem.resize(twsize()); + comp_twiddle(); + } +}; + +// +// complex Bluestein transforms +// + +template class fftblue + { + private: + size_t n, n2; + cfftp plan; + arr> mem; + cmplx *bk, *bkf; + + template void fft(cmplx c[], T0 fct) const + { + arr> akf(n2); + + /* initialize a_k and FFT it */ + for (size_t m=0; m(c[m],bk[m],akf[m]); + auto zero = akf[0]*T0(0); + for (size_t m=n; m(bkf[0]); + for (size_t m=1; m<(n2+1)/2; ++m) + { + akf[m] = akf[m].template special_mul(bkf[m]); + akf[n2-m] = akf[n2-m].template special_mul(bkf[m]); + } + if ((n2&1)==0) + akf[n2/2] 
= akf[n2/2].template special_mul(bkf[n2/2]); + + /* inverse FFT */ + plan.exec (akf.data(),1.,false); + + /* multiply by b_k */ + for (size_t m=0; m(bk[m])*fct; + } + + public: + POCKETFFT_NOINLINE fftblue(size_t length) + : n(length), n2(util::good_size_cmplx(n*2-1)), plan(n2), mem(n+n2/2+1), + bk(mem.data()), bkf(mem.data()+n) + { + /* initialize b_k */ + sincos_2pibyn tmp(2*n); + bk[0].Set(1, 0); + + size_t coeff=0; + for (size_t m=1; m=2*n) coeff-=2*n; + bk[m] = tmp[coeff]; + } + + /* initialize the zero-padded, Fourier transformed b_k. Add normalisation. */ + arr> tbkf(n2); + T0 xn2 = T0(1)/T0(n2); + tbkf[0] = bk[0]*xn2; + for (size_t m=1; m void exec(cmplx c[], T0 fct, bool fwd) const + { fwd ? fft(c,fct) : fft(c,fct); } + + template void exec_r(T c[], T0 fct, bool fwd) + { + arr> tmp(n); + if (fwd) + { + auto zero = T0(0)*c[0]; + for (size_t m=0; m(tmp.data(),fct); + c[0] = tmp[0].r; + std::copy_n (&tmp[1].r, n-1, &c[1]); + } + else + { + tmp[0].Set(c[0],c[0]*0); + std::copy_n (c+1, n-1, &tmp[1].r); + if ((n&1)==0) tmp[n/2].i=T0(0)*c[0]; + for (size_t m=1; 2*m(tmp.data(),fct); + for (size_t m=0; m class pocketfft_c + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_c(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new cfftp(length)); + return; + } + double comp1 = util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new cfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(cmplx c[], T0 fct, bool fwd) const + { packplan ? 
packplan->exec(c,fct,fwd) : blueplan->exec(c,fct,fwd); } + + size_t length() const { return len; } + }; + +// +// flexible (FFTPACK/Bluestein) real-valued 1D transform +// + +template class pocketfft_r + { + private: + std::unique_ptr> packplan; + std::unique_ptr> blueplan; + size_t len; + + public: + POCKETFFT_NOINLINE pocketfft_r(size_t length) + : len(length) + { + if (length==0) throw std::runtime_error("zero-length FFT requested"); + size_t tmp = (length<50) ? 0 : util::largest_prime_factor(length); + if (tmp*tmp <= length) + { + packplan=std::unique_ptr>(new rfftp(length)); + return; + } + double comp1 = 0.5*util::cost_guess(length); + double comp2 = 2*util::cost_guess(util::good_size_cmplx(2*length-1)); + comp2*=1.5; /* fudge factor that appears to give good overall performance */ + if (comp2>(new fftblue(length)); + else + packplan=std::unique_ptr>(new rfftp(length)); + } + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool fwd) const + { packplan ? packplan->exec(c,fct,fwd) : blueplan->exec_r(c,fct,fwd); } + + size_t length() const { return len; } + }; + + +// +// sine/cosine transforms +// + +template class T_dct1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dct1(size_t length) + : fftplan(2*(length-1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int /*type*/, bool /*cosine*/) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=fftplan.length(), n=N/2+1; + if (ortho) + { c[0]*=sqrt2; c[n-1]*=sqrt2; } + arr tmp(N); + tmp[0] = c[0]; + for (size_t i=1; i class T_dst1 + { + private: + pocketfft_r fftplan; + + public: + POCKETFFT_NOINLINE T_dst1(size_t length) + : fftplan(2*(length+1)) {} + + template POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool /*cosine*/) const + { + size_t N=fftplan.length(), n=N/2-1; + arr tmp(N); + tmp[0] = tmp[n+1] = c[0]*0; + for (size_t i=0; i class T_dcst23 + { + private: + pocketfft_r fftplan; + 
std::vector twiddle; + + public: + POCKETFFT_NOINLINE T_dcst23(size_t length) + : fftplan(length), twiddle(length) + { + sincos_2pibyn tw(4*length); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, bool ortho, + int type, bool cosine) const + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + size_t N=length(); + size_t NS2 = (N+1)/2; + if (type==2) + { + if (!cosine) + for (size_t k=1; k class T_dcst4 + { + private: + size_t N; + std::unique_ptr> fft; + std::unique_ptr> rfft; + arr> C2; + + public: + POCKETFFT_NOINLINE T_dcst4(size_t length) + : N(length), + fft((N&1) ? nullptr : new pocketfft_c(N/2)), + rfft((N&1)? new pocketfft_r(N) : nullptr), + C2((N&1) ? 0 : N/2) + { + if ((N&1)==0) + { + sincos_2pibyn tw(16*N); + for (size_t i=0; i POCKETFFT_NOINLINE void exec(T c[], T0 fct, + bool /*ortho*/, int /*type*/, bool cosine) const + { + size_t n2 = N/2; + if (!cosine) + for (size_t k=0, kc=N-1; k y(N); + { + size_t i=0, m=n2; + for (; mexec(y.data(), fct, true); + { + auto SGN = [](size_t i) + { + constexpr T0 sqrt2=T0(1.414213562373095048801688724209698L); + return (i&2) ? 
-sqrt2 : sqrt2; + }; + c[n2] = y[0]*SGN(n2+1); + size_t i=0, i1=1, k=1; + for (; k> y(n2); + for(size_t i=0; iexec(y.data(), fct, true); + for(size_t i=0, ic=n2-1; i std::shared_ptr get_plan(size_t length) + { +#if POCKETFFT_CACHE_SIZE==0 + return std::make_shared(length); +#else + constexpr size_t nmax=POCKETFFT_CACHE_SIZE; + static std::array, nmax> cache; + static std::array last_access{{0}}; + static size_t access_counter = 0; + static std::mutex mut; + + auto find_in_cache = [&]() -> std::shared_ptr + { + for (size_t i=0; ilength()==length)) + { + // no need to update if this is already the most recent entry + if (last_access[i]!=access_counter) + { + last_access[i] = ++access_counter; + // Guard against overflow + if (access_counter == 0) + last_access.fill(0); + } + return cache[i]; + } + + return nullptr; + }; + + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + } + auto plan = std::make_shared(length); + { + std::lock_guard lock(mut); + auto p = find_in_cache(); + if (p) return p; + + size_t lru = 0; + for (size_t i=1; i class cndarr: public arr_info + { + protected: + const char *d; + + public: + cndarr(const void *data_, const shape_t &shape_, const stride_t &stride_) + : arr_info(shape_, stride_), + d(reinterpret_cast(data_)) {} + const T &operator[](ptrdiff_t ofs) const + { return *reinterpret_cast(d+ofs); } + }; + +template class ndarr: public cndarr + { + public: + ndarr(void *data_, const shape_t &shape_, const stride_t &stride_) + : cndarr::cndarr(const_cast(data_), shape_, stride_) + {} + T &operator[](ptrdiff_t ofs) + { return *reinterpret_cast(const_cast(cndarr::d+ofs)); } + }; + +template class multi_iter + { + private: + shape_t pos; + const arr_info &iarr, &oarr; + ptrdiff_t p_ii, p_i[N], str_i, p_oi, p_o[N], str_o; + size_t idim, rem; + + void advance_i() + { + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + if (i==idim) continue; + p_ii += iarr.stride(i); + p_oi += oarr.stride(i); + if 
(++pos[i] < iarr.shape(i)) + return; + pos[i] = 0; + p_ii -= ptrdiff_t(iarr.shape(i))*iarr.stride(i); + p_oi -= ptrdiff_t(oarr.shape(i))*oarr.stride(i); + } + } + + public: + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), + str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), + idim(idim_), rem(iarr.size()/iarr.shape(idim)) + { + auto nshares = threading::num_threads(); + if (nshares==1) return; + if (nshares==0) throw std::runtime_error("can't run with zero threads"); + auto myshare = threading::thread_id(); + if (myshare>=nshares) throw std::runtime_error("impossible share requested"); + size_t nbase = rem/nshares; + size_t additional = rem%nshares; + size_t lo = myshare*nbase + ((myshare=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (++pos[i] < arr.shape(i)) + return; + pos[i] = 0; + p -= ptrdiff_t(arr.shape(i))*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + size_t remaining() const { return rem; } + }; + +class rev_iter + { + private: + shape_t pos; + const arr_info &arr; + std::vector rev_axis; + std::vector rev_jump; + size_t last_axis, last_size; + shape_t shp; + ptrdiff_t p, rp; + size_t rem; + + public: + rev_iter(const arr_info &arr_, const shape_t &axes) + : pos(arr_.ndim(), 0), arr(arr_), rev_axis(arr_.ndim(), 0), + rev_jump(arr_.ndim(), 1), p(0), rp(0) + { + for (auto ax: axes) + rev_axis[ax]=1; + last_axis = axes.back(); + last_size = arr.shape(last_axis)/2 + 1; + shp = arr.shape(); + shp[last_axis] = last_size; + rem=1; + for (auto i: shp) + rem *= i; + } + void advance() + { + --rem; + for (int i_=int(pos.size())-1; i_>=0; --i_) + { + auto i = size_t(i_); + p += arr.stride(i); + if (!rev_axis[i]) + rp += arr.stride(i); + else + { + rp -= arr.stride(i); + if (rev_jump[i]) + { + rp += ptrdiff_t(arr.shape(i))*arr.stride(i); + rev_jump[i] = 0; + } + } + if (++pos[i] < shp[i]) + return; + pos[i] = 0; + p -= 
ptrdiff_t(shp[i])*arr.stride(i); + if (rev_axis[i]) + { + rp -= ptrdiff_t(arr.shape(i)-shp[i])*arr.stride(i); + rev_jump[i] = 1; + } + else + rp -= ptrdiff_t(shp[i])*arr.stride(i); + } + } + ptrdiff_t ofs() const { return p; } + ptrdiff_t rev_ofs() const { return rp; } + size_t remaining() const { return rem; } + }; + +template struct VTYPE {}; +template using vtype_t = typename VTYPE::type; + +#ifndef POCKETFFT_NO_VECTORS +template<> struct VTYPE + { + using type = float __attribute__ ((vector_size (VLEN::val*sizeof(float)))); + }; +template<> struct VTYPE + { + using type = double __attribute__ ((vector_size (VLEN::val*sizeof(double)))); + }; +template<> struct VTYPE + { + using type = long double __attribute__ ((vector_size (VLEN::val*sizeof(long double)))); + }; +#endif + +template arr alloc_tmp(const shape_t &shape, + size_t axsize, size_t elemsize) + { + auto othersize = util::prod(shape)/axsize; + auto tmpsize = axsize*((othersize>=VLEN::val) ? VLEN::val : 1); + return arr(tmpsize*elemsize); + } +template arr alloc_tmp(const shape_t &shape, + const shape_t &axes, size_t elemsize) + { + size_t fullsize=util::prod(shape); + size_t tmpsize=0; + for (size_t i=0; i=VLEN::val) ? 
VLEN::val : 1); + if (sz>tmpsize) tmpsize=sz; + } + return arr(tmpsize*elemsize); + } + +template void copy_input(const multi_iter &it, + const cndarr> &src, cmplx> *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, vtype_t *POCKETFFT_RESTRICT dst) + { + for (size_t i=0; i void copy_input(const multi_iter &it, + const cndarr &src, T *POCKETFFT_RESTRICT dst) + { + if (dst == &src[it.iofs(0)]) return; // in-place + for (size_t i=0; i void copy_output(const multi_iter &it, + const cmplx> *POCKETFFT_RESTRICT src, ndarr> &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t i=0; i void copy_output(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + if (src == &dst[it.oofs(0)]) return; // in-place + for (size_t i=0; i struct add_vec { using type = vtype_t; }; +template struct add_vec> + { using type = cmplx>; }; +template using add_vec_t = typename add_vec::type; + +template +POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, + const shape_t &axes, T0 fct, size_t nthreads, const Exec & exec, + const bool allow_inplace=true) + { + std::shared_ptr plan; + + for (size_t iax=0; iaxlength())) + plan = get_plan(len); + + threading::thread_map( + util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + const auto &tin(iax==0? in : out); + multi_iter it(tin, out, axes[iax]); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + exec(it, tin, out, tdatav, *plan, fct); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto buf = allow_inplace && it.stride_out() == sizeof(T) ? 
+ &out[it.oofs(0)] : reinterpret_cast(storage.data()); + exec(it, tin, out, buf, *plan, fct); + } + }); // end of parallel region + fct = T0(1); // factor has been applied, use 1 for remaining axes + } + } + +struct ExecC2C + { + bool forward; + + template void operator () ( + const multi_iter &it, const cndarr> &in, + ndarr> &out, T * buf, const pocketfft_c &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, forward); + copy_output(it, buf, out); + } + }; + +template void copy_hartley(const multi_iter &it, + const vtype_t *POCKETFFT_RESTRICT src, ndarr &dst) + { + for (size_t j=0; j void copy_hartley(const multi_iter &it, + const T *POCKETFFT_RESTRICT src, ndarr &dst) + { + dst[it.oofs(0)] = src[0]; + size_t i=1, i1=1, i2=it.length_out()-1; + for (i=1; i void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, + T * buf, const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, true); + copy_hartley(it, buf, out); + } + }; + +struct ExecDcst + { + bool ortho; + int type; + bool cosine; + + template + void operator () (const multi_iter &it, const cndarr &in, + ndarr &out, T * buf, const Tplan &plan, T0 fct) const + { + copy_input(it, in, buf); + plan.exec(buf, fct, ortho, type, cosine); + copy_output(it, buf, out); + } + }; + +template POCKETFFT_NOINLINE void general_r2c( + const cndarr &in, ndarr> &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(in.shape(axis)); + size_t len=in.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(in.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + copy_input(it, in, tdatav); + plan->exec(tdatav, fct, true); + for (size_t j=0; j0) + { + 
it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + copy_input(it, in, tdata); + plan->exec(tdata, fct, true); + out[it.oofs(0)].Set(tdata[0]); + size_t i=1, ii=1; + if (forward) + for (; i POCKETFFT_NOINLINE void general_c2r( + const cndarr> &in, ndarr &out, size_t axis, bool forward, T fct, + size_t nthreads) + { + auto plan = get_plan>(out.shape(axis)); + size_t len=out.shape(axis); + threading::thread_map( + util::thread_count(nthreads, in.shape(), axis, VLEN::val), + [&] { + constexpr auto vlen = VLEN::val; + auto storage = alloc_tmp(out.shape(), len, sizeof(T)); + multi_iter it(in, out, axis); +#ifndef POCKETFFT_NO_VECTORS + if (vlen>1) + while (it.remaining()>=vlen) + { + it.advance(vlen); + auto tdatav = reinterpret_cast *>(storage.data()); + for (size_t j=0; jexec(tdatav, fct, false); + copy_output(it, tdatav, out); + } +#endif + while (it.remaining()>0) + { + it.advance(1); + auto tdata = reinterpret_cast(storage.data()); + tdata[0]=in[it.iofs(0)].r; + { + size_t i=1, ii=1; + if (forward) + for (; iexec(tdata, fct, false); + copy_output(it, tdata, out); + } + }); // end of parallel region + } + +struct ExecR2R + { + bool r2h, forward; + + template void operator () ( + const multi_iter &it, const cndarr &in, ndarr &out, T * buf, + const pocketfft_r &plan, T0 fct) const + { + copy_input(it, in, buf); + if ((!r2h) && forward) + for (size_t i=2; i void c2c(const shape_t &shape, const stride_t &stride_in, + const stride_t &stride_out, const shape_t &axes, bool forward, + const std::complex *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr> ain(data_in, shape, stride_in); + ndarr> aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecC2C{forward}); + } + +template void dct(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int 
type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DCT type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, true}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void dst(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + int type, const T *data_in, T *data_out, T fct, bool ortho, size_t nthreads=1) + { + if ((type<1) || (type>4)) throw std::invalid_argument("invalid DST type"); + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + const ExecDcst exec{ortho, type, false}; + if (type==1) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else if (type==4) + general_nd>(ain, aout, axes, fct, nthreads, exec); + else + general_nd>(ain, aout, axes, fct, nthreads, exec); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axis); + cndarr ain(data_in, shape_in, stride_in); + shape_t shape_out(shape_in); + shape_out[axis] = shape_in[axis]/2 + 1; + ndarr> aout(data_out, shape_out, stride_out); + general_r2c(ain, aout, axis, forward, fct, nthreads); + } + +template void r2c(const shape_t &shape_in, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool 
forward, const T *data_in, std::complex *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_in)==0) return; + util::sanity_check(shape_in, stride_in, stride_out, false, axes); + r2c(shape_in, stride_in, stride_out, axes.back(), forward, data_in, data_out, + fct, nthreads); + if (axes.size()==1) return; + + shape_t shape_out(shape_in); + shape_out[axes.back()] = shape_in[axes.back()]/2 + 1; + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_out, stride_out, stride_out, newaxes, forward, data_out, data_out, + T(1), nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, size_t axis, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + util::sanity_check(shape_out, stride_in, stride_out, false, axis); + shape_t shape_in(shape_out); + shape_in[axis] = shape_out[axis]/2 + 1; + cndarr> ain(data_in, shape_in, stride_in); + ndarr aout(data_out, shape_out, stride_out); + general_c2r(ain, aout, axis, forward, fct, nthreads); + } + +template void c2r(const shape_t &shape_out, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool forward, const std::complex *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape_out)==0) return; + if (axes.size()==1) + return c2r(shape_out, stride_in, stride_out, axes[0], forward, + data_in, data_out, fct, nthreads); + util::sanity_check(shape_out, stride_in, stride_out, false, axes); + auto shape_in = shape_out; + shape_in[axes.back()] = shape_out[axes.back()]/2 + 1; + auto nval = util::prod(shape_in); + stride_t stride_inter(shape_in.size()); + stride_inter.back() = sizeof(cmplx); + for (int i=int(shape_in.size())-2; i>=0; --i) + stride_inter[size_t(i)] = + stride_inter[size_t(i+1)]*ptrdiff_t(shape_in[size_t(i+1)]); + arr> tmp(nval); + auto newaxes = shape_t{axes.begin(), --axes.end()}; + c2c(shape_in, stride_in, 
stride_inter, newaxes, forward, data_in, tmp.data(), + T(1), nthreads); + c2r(shape_out, stride_inter, stride_out, axes.back(), forward, + tmp.data(), data_out, fct, nthreads); + } + +template void r2r_fftpack(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + bool real2hermitian, bool forward, const T *data_in, T *data_out, T fct, + size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, + ExecR2R{real2hermitian, forward}); + } + +template void r2r_separable_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + cndarr ain(data_in, shape, stride_in); + ndarr aout(data_out, shape, stride_out); + general_nd>(ain, aout, axes, fct, nthreads, ExecHartley{}, + false); + } + +template void r2r_genuine_hartley(const shape_t &shape, + const stride_t &stride_in, const stride_t &stride_out, const shape_t &axes, + const T *data_in, T *data_out, T fct, size_t nthreads=1) + { + if (util::prod(shape)==0) return; + if (axes.size()==1) + return r2r_separable_hartley(shape, stride_in, stride_out, axes, data_in, + data_out, fct, nthreads); + util::sanity_check(shape, stride_in, stride_out, data_in==data_out, axes); + shape_t tshp(shape); + tshp[axes.back()] = tshp[axes.back()]/2+1; + arr> tdata(util::prod(tshp)); + stride_t tstride(shape.size()); + tstride.back()=sizeof(std::complex); + for (size_t i=tstride.size()-1; i>0; --i) + tstride[i-1]=tstride[i]*ptrdiff_t(tshp[i]); + r2c(shape, stride_in, tstride, axes, true, data_in, tdata.data(), fct, nthreads); + cndarr> atmp(tdata.data(), tshp, tstride); 
+ ndarr aout(data_out, shape, stride_out); + simple_iter iin(atmp); + rev_iter iout(aout, axes); + while(iin.remaining()>0) + { + auto v = atmp[iin.ofs()]; + aout[iout.ofs()] = v.r+v.i; + aout[iout.rev_ofs()] = v.r-v.i; + iin.advance(); iout.advance(); + } + } + +} // namespace detail + +using detail::FORWARD; +using detail::BACKWARD; +using detail::shape_t; +using detail::stride_t; +using detail::c2c; +using detail::c2r; +using detail::r2c; +using detail::r2r_fftpack; +using detail::r2r_separable_hartley; +using detail::r2r_genuine_hartley; +using detail::dct; +using detail::dst; + +} // namespace pocketfft + +#undef POCKETFFT_NOINLINE +#undef POCKETFFT_RESTRICT + +#endif // POCKETFFT_HDRONLY_H diff --git a/dist/include/mlx/allocator.h b/dist/include/mlx/allocator.h new file mode 100644 index 0000000..cd6a78e --- /dev/null +++ b/dist/include/mlx/allocator.h @@ -0,0 +1,73 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +namespace mlx::core::allocator { + +// Simple wrapper around buffer pointers +// WARNING: Only Buffer objects constructed from and those that wrap +// raw pointers from mlx::allocator are supported. +class Buffer { + private: + void* ptr_; + + public: + explicit Buffer(void* ptr) : ptr_(ptr) {}; + + // Get the raw data pointer from the buffer + void* raw_ptr(); + + // Get the buffer pointer from the buffer + const void* ptr() const { + return ptr_; + }; + void* ptr() { + return ptr_; + }; +}; + +class Allocator { + /** Abstract base class for a memory allocator. 
*/ + public: + virtual Buffer malloc(size_t size) = 0; + virtual void free(Buffer buffer) = 0; + virtual size_t size(Buffer buffer) const = 0; + virtual Buffer make_buffer(void* ptr, size_t size) { + return Buffer{nullptr}; + }; + virtual void release(Buffer buffer) {} + + Allocator() = default; + Allocator(const Allocator& other) = delete; + Allocator(Allocator&& other) = delete; + Allocator& operator=(const Allocator& other) = delete; + Allocator& operator=(Allocator&& other) = delete; + virtual ~Allocator() = default; +}; + +Allocator& allocator(); + +inline Buffer malloc(size_t size) { + return allocator().malloc(size); +} + +inline void free(Buffer buffer) { + allocator().free(buffer); +} + +// Make a Buffer from a raw pointer of the given size without a copy. If a +// no-copy conversion is not possible then the returned buffer.ptr() will be +// nullptr. Any buffer created with this function must be released with +// release(buffer) +inline Buffer make_buffer(void* ptr, size_t size) { + return allocator().make_buffer(ptr, size); +}; + +// Release a buffer from the allocator made with make_buffer +inline void release(Buffer buffer) { + allocator().release(buffer); +} + +} // namespace mlx::core::allocator diff --git a/dist/include/mlx/array.h b/dist/include/mlx/array.h new file mode 100644 index 0000000..645fa68 --- /dev/null +++ b/dist/include/mlx/array.h @@ -0,0 +1,645 @@ +// Copyright © 2023 Apple Inc. +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/dtype.h" +#include "mlx/event.h" +#include "mlx/small_vector.h" + +namespace mlx::core { + +// Forward declaration +class Primitive; + +using Deleter = std::function; +using ShapeElem = int32_t; +using Shape = SmallVector; +using Strides = SmallVector; + +class array { + /* An array is really a node in a graph. It contains a shared ArrayDesc + * object */ + + public: + /** Construct a scalar array with zero dimensions. 
*/ + template + explicit array(T val, Dtype dtype = TypeToDtype()); + + /* Special case since std::complex can't be implicitly converted to other + * types. */ + explicit array(const std::complex& val, Dtype dtype = complex64); + + template + explicit array( + It data, + Shape shape, + Dtype dtype = + TypeToDtype::value_type>()); + + template + explicit array(std::initializer_list data, Dtype dtype = TypeToDtype()); + + /* Special case so empty lists default to float32. */ + explicit array(std::initializer_list data); + + /* Special case so array({}, type) is an empty array. */ + explicit array(std::initializer_list data, Dtype dtype); + + template + explicit array( + std::initializer_list data, + Shape shape, + Dtype dtype = TypeToDtype()); + + /* Build an array from a raw pointer. The constructor will attempt to use the + * input data without a copy. The deleter will be called when the array no + * longer needs the underlying memory - after the array is destroyed in the + * no-copy case and after the copy otherwise. */ + explicit array( + void* data, + Shape shape, + Dtype dtype, + const std::function& deleter); + + /* Build an array from a buffer */ + explicit array( + allocator::Buffer data, + Shape shape, + Dtype dtype, + Deleter deleter = allocator::free); + + /** Assignment to rvalue does not compile. */ + array& operator=(const array& other) && = delete; + array& operator=(array&& other) && = delete; + + /** Default copy and move constructors otherwise. */ + array& operator=(array&& other) & = default; + array(const array& other) = default; + array(array&& other) = default; + + array& operator=(const array& other) & { + if (this->id() != other.id()) { + this->array_desc_ = other.array_desc_; + } + return *this; + } + + /** The size of the array's datatype in bytes. */ + size_t itemsize() const { + return size_of(dtype()); + } + + /** The number of elements in the array. 
*/ + size_t size() const { + return array_desc_->size; + } + + /** The number of bytes in the array. */ + size_t nbytes() const { + return size() * itemsize(); + } + + /** The number of dimensions of the array. */ + size_t ndim() const { + return array_desc_->shape.size(); + } + + /** The shape of the array as a vector of integers. */ + const Shape& shape() const { + return array_desc_->shape; + } + + /** + * Get the size of the corresponding dimension. + * + * This function supports negative indexing and provides + * bounds checking. */ + auto shape(int dim) const { + return shape().at(dim < 0 ? dim + ndim() : dim); + } + + /** The strides of the array. */ + const Strides& strides() const { + return array_desc_->strides; + } + + /** + * Get the stride of the corresponding dimension. + * + * This function supports negative indexing and provides + * bounds checking. */ + auto strides(int dim) const { + return strides().at(dim < 0 ? dim + ndim() : dim); + } + + /** Get the arrays data type. */ + Dtype dtype() const { + return array_desc_->dtype; + } + + /** Evaluate the array. */ + void eval(); + + /** Get the value from a scalar array. 
*/ + template + T item(); + + template + T item() const; + + struct ArrayIterator { + using iterator_category = std::random_access_iterator_tag; + using difference_type = size_t; + using value_type = const array; + using reference = value_type; + + explicit ArrayIterator(const array& arr, int idx = 0); + + reference operator*() const; + + ArrayIterator& operator+(difference_type diff) { + idx += diff; + return *this; + } + + ArrayIterator& operator++() { + idx++; + return *this; + } + + friend bool operator==(const ArrayIterator& a, const ArrayIterator& b) { + return a.arr.id() == b.arr.id() && a.idx == b.idx; + } + friend bool operator!=(const ArrayIterator& a, const ArrayIterator& b) { + return !(a == b); + } + + private: + const array& arr; + int idx; + }; + + ArrayIterator begin() const { + return ArrayIterator(*this); + } + ArrayIterator end() const { + return ArrayIterator(*this, shape(0)); + } + + /** + * The following methods should be used with caution. + * They are intended for use by the backend implementation and the + * API may change. + */ + + array( + Shape shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector inputs); + + static std::vector make_arrays( + std::vector shapes, + const std::vector& dtypes, + const std::shared_ptr& primitive, + const std::vector& inputs); + + /** + * Get a new array that refers to the same data as the input but with a + * non-owning pointer to it. Note the array is detached from the graph and has + * no inputs, siblings or primitive. + */ + static array unsafe_weak_copy(const array& other); + + /** A unique identifier for an array. */ + std::uintptr_t id() const { + return reinterpret_cast(array_desc_.get()); + } + + /** A unique identifier for an arrays primitive. 
*/ + std::uintptr_t primitive_id() const { + return reinterpret_cast(array_desc_->primitive.get()); + } + + struct Data { + allocator::Buffer buffer; + Deleter d; + Data(allocator::Buffer buffer, Deleter d = allocator::free) + : buffer(buffer), d(d) {} + // Not copyable + Data(const Data& d) = delete; + Data& operator=(const Data& d) = delete; + Data(Data&& o) : buffer(o.buffer), d(o.d) { + o.buffer = allocator::Buffer(nullptr); + o.d = [](allocator::Buffer) {}; + } + ~Data() { + d(buffer); + } + }; + + struct Flags { + // True iff there are no gaps in the underlying data. Each item + // in the underlying data buffer belongs to at least one index. + // + // True iff: + // prod(shape[i] for i in range(ndim) if strides[i] > 0) == data_size() + bool contiguous : 1; + + // True iff: + // strides[-1] == 1 and + // all(strides[i] == (shape[i+1]*strides[i+1]) or shape[i] == 1 for i in + // range(ndim - 1)) + bool row_contiguous : 1; + + // True iff: + // strides[0] == 1 and + // all(strides[i] == (shape[i-1]*strides[i-1]) or shape[i] == 1 for i in + // range(1, ndim)) + bool col_contiguous : 1; + }; + + /** The array's primitive. */ + Primitive& primitive() const { + return *(array_desc_->primitive); + } + + /** A shared pointer to the array's primitive. */ + std::shared_ptr& primitive_ptr() const { + return array_desc_->primitive; + } + + /** Check if the array has an attached primitive or is a leaf node. */ + bool has_primitive() const { + return array_desc_->primitive != nullptr; + } + + /** The array's inputs. */ + const std::vector& inputs() const { + return array_desc_->inputs; + } + + std::vector& inputs() { + return array_desc_->inputs; + } + + /** True indicates the arrays buffer is safe to reuse */ + bool is_donatable() const { + return array_desc_.use_count() == 1 && (array_desc_->data.use_count() == 1); + } + + /** The array's siblings. */ + const std::vector& siblings() const { + return array_desc_->siblings; + } + + /** The array's siblings. 
*/ + std::vector& siblings() { + return array_desc_->siblings; + } + + /** The array's position in the sibling list. */ + int sibling_position() const { + return array_desc_->position; + } + + void set_siblings(std::vector siblings, uint16_t position) { + array_desc_->siblings = std::move(siblings); + array_desc_->position = position; + } + + /** The outputs of the array's primitive (i.e. this array and + * its siblings) in the order the primitive expects. */ + std::vector outputs() const { + auto idx = array_desc_->position; + std::vector outputs; + outputs.reserve(siblings().size() + 1); + outputs.insert(outputs.end(), siblings().begin(), siblings().begin() + idx); + outputs.push_back(*this); + outputs.insert(outputs.end(), siblings().begin() + idx, siblings().end()); + return outputs; + } + + /** Detach the array from the graph. */ + void detach(); + + /** Get the Flags bit-field. */ + const Flags& flags() const { + return array_desc_->flags; + } + + /** The size (in elements) of the underlying buffer the array points to. + * + * This can be different than the actual size of the array if the array has + * been broadcast or irregularly strided. If ``first`` is the offset into + * the data buffer of the first element of the array (i.e. the offset + * corresponding to ``arr[0, 0, ...]``) and last is the offset into the + * data buffer of the last element of the array (i.e. the offset + * corresponding to ``arr[-1, -1, ...]``) then ``data_size = last - first``. + * Note, ``data_size`` is in units of ``item_size`` (not bytes). 
+ **/ + size_t data_size() const { + return array_desc_->data_size; + } + + allocator::Buffer& buffer() { + return array_desc_->data->buffer; + } + const allocator::Buffer& buffer() const { + return array_desc_->data->buffer; + } + + size_t buffer_size() const { + return allocator::allocator().size(buffer()); + } + + // Return the shared pointer to the array::Data struct + const std::shared_ptr& data_shared_ptr() const { + return array_desc_->data; + } + + // Return a raw pointer to the arrays data. This function may do a copy if + // the underlying buffer is not accessible on the CPU. When accessing the + // data for GPU kernels, be sure to use the correct method / function for the + // given backend to access the GPU pointer. + template + T* data() { + return reinterpret_cast( + (static_cast(buffer().raw_ptr()) + array_desc_->offset)); + } + + template + const T* data() const { + return const_cast(*this).data(); + } + + int64_t offset() const { + return array_desc_->offset; + } + + enum Status { + // The output of a computation which has not been scheduled. + // For example, the status of `x` in `auto x = a + b`. + unscheduled, + + // The array's `eval_*` function has been run, but the computation is not + // necessarily complete. The array will have memory allocated and if it is + // not a tracer then it will be detached from the graph. + evaluated, + + // If the array is the output of a computation then the computation + // is complete. Constant arrays are always available (e.g. `array({1, 2, + // 3})`) + available + }; + + // Check if the array is safe to read. + bool is_available() const; + + // Wait on the array to be available. After this `is_available` returns + // `true`. 
+ void wait(); + + Status status() const { + return array_desc_->status; + } + + void set_status(Status s) const { + array_desc_->status = s; + } + + // Get the array's shared event + Event& event() const { + return array_desc_->event; + } + + // Attach an event to a not yet evaluated array + void attach_event(Event e) const { + array_desc_->event = std::move(e); + } + + void detach_event() const { + array_desc_->event = Event{}; + } + + // Mark the array as a tracer array (true) or not. + void set_tracer(bool is_tracer) { + array_desc_->is_tracer = is_tracer; + } + // Check if the array is a tracer array + bool is_tracer() const; + + void set_data(allocator::Buffer buffer, Deleter d = allocator::free); + + void set_data( + allocator::Buffer buffer, + size_t data_size, + Strides strides, + Flags flags, + Deleter d = allocator::free); + + void copy_shared_buffer( + const array& other, + const Strides& strides, + Flags flags, + size_t data_size, + int64_t offset = 0); + + void copy_shared_buffer(const array& other); + + void overwrite_descriptor(const array& other) { + array_desc_ = other.array_desc_; + } + + ~array(); + + private: + // Initialize the arrays data + template + void init(const It src); + + struct ArrayDesc { + Shape shape; + Strides strides; + size_t size; + Dtype dtype; + std::shared_ptr primitive; + + Status status; + + // An event on the array used for synchronization + Event event; + + // Indicates an array is being used in a graph transform + // and should not be detached from the graph + bool is_tracer{false}; + + // This is a shared pointer so that *different* arrays + // can share the underlying data buffer. 
+ std::shared_ptr data; + + // Offset from beginning of data pointer + int64_t offset{0}; + + // The size in elements of the data buffer the array accesses + size_t data_size; + + // Contains useful meta data about the array + Flags flags; + + std::vector inputs; + // An array to keep track of the siblings from a multi-output + // primitive. + std::vector siblings; + // The arrays position in the output list + uint32_t position{0}; + + explicit ArrayDesc(Shape shape, Dtype dtype); + + explicit ArrayDesc( + Shape shape, + Dtype dtype, + std::shared_ptr primitive, + std::vector inputs); + + ~ArrayDesc(); + + private: + // Initialize size, strides, and other metadata + void init(); + }; + + // The ArrayDesc contains the details of the materialized array including the + // shape, strides, the data type. It also includes + // the primitive which knows how to compute the array's data from its inputs + // and the list of array's inputs for the primitive. + std::shared_ptr array_desc_; +}; + +template +array::array(T val, Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(Shape{}, dtype)) { + init(&val); +} + +template +array::array( + It data, + Shape shape, + Dtype dtype /* = TypeToDtype::value_type>() */) : + array_desc_(std::make_shared(std::move(shape), dtype)) { + init(data); +} + +template +array::array( + std::initializer_list data, + Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared( + Shape{static_cast(data.size())}, + dtype)) { + init(data.begin()); +} + +template +array::array( + std::initializer_list data, + Shape shape, + Dtype dtype /* = TypeToDtype() */) + : array_desc_(std::make_shared(std::move(shape), dtype)) { + if (data.size() != size()) { + throw std::invalid_argument( + "Data size and provided shape mismatch in array construction."); + } + init(data.begin()); +} + +template +T array::item() { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + eval(); + return 
*data(); +} + +template +T array::item() const { + if (size() != 1) { + throw std::invalid_argument("item can only be called on arrays of size 1."); + } + if (status() == Status::unscheduled) { + throw std::invalid_argument( + "item() const can only be called on evaled arrays"); + } + const_cast(this)->eval(); + return *data(); +} + +template +void array::init(It src) { + set_data(allocator::malloc(size() * size_of(dtype()))); + switch (dtype()) { + case bool_: + std::copy(src, src + size(), data()); + break; + case uint8: + std::copy(src, src + size(), data()); + break; + case uint16: + std::copy(src, src + size(), data()); + break; + case uint32: + std::copy(src, src + size(), data()); + break; + case uint64: + std::copy(src, src + size(), data()); + break; + case int8: + std::copy(src, src + size(), data()); + break; + case int16: + std::copy(src, src + size(), data()); + break; + case int32: + std::copy(src, src + size(), data()); + break; + case int64: + std::copy(src, src + size(), data()); + break; + case float16: + std::copy(src, src + size(), data()); + break; + case float32: + std::copy(src, src + size(), data()); + break; + case float64: + std::copy(src, src + size(), data()); + break; + case bfloat16: + std::copy(src, src + size(), data()); + break; + case complex64: + std::copy(src, src + size(), data()); + break; + } +} + +/* Utilities for determining whether a template parameter is array. */ +template +inline constexpr bool is_array_v = + std::is_same_v>, array>; + +template +inline constexpr bool is_arrays_v = (is_array_v && ...); + +template +using enable_for_arrays_t = typename std::enable_if_t>; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/binary.h b/dist/include/mlx/backend/common/binary.h new file mode 100644 index 0000000..78607ef --- /dev/null +++ b/dist/include/mlx/backend/common/binary.h @@ -0,0 +1,97 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum class BinaryOpType { + ScalarScalar, + ScalarVector, + VectorScalar, + VectorVector, + General, +}; + +inline BinaryOpType get_binary_op_type(const array& a, const array& b) { + BinaryOpType bopt; + if (a.data_size() == 1 && b.data_size() == 1) { + bopt = BinaryOpType::ScalarScalar; + } else if (a.data_size() == 1 && b.flags().contiguous) { + bopt = BinaryOpType::ScalarVector; + } else if (b.data_size() == 1 && a.flags().contiguous) { + bopt = BinaryOpType::VectorScalar; + } else if ( + (a.flags().row_contiguous && b.flags().row_contiguous) || + (a.flags().col_contiguous && b.flags().col_contiguous)) { + bopt = BinaryOpType::VectorVector; + } else { + bopt = BinaryOpType::General; + } + return bopt; +} + +inline void set_binary_op_output_data( + const array& a, + const array& b, + array& out, + BinaryOpType bopt, + std::function mallocfn = allocator::malloc) { + bool b_donatable = is_donatable(b, out); + bool a_donatable = is_donatable(a, out); + switch (bopt) { + case BinaryOpType::ScalarScalar: + out.set_data(mallocfn(out.itemsize()), 1, a.strides(), a.flags()); + break; + case BinaryOpType::ScalarVector: + if (b_donatable) { + out.copy_shared_buffer(b); + } else { + out.set_data( + mallocfn(b.data_size() * out.itemsize()), + b.data_size(), + b.strides(), + b.flags()); + } + break; + case BinaryOpType::VectorScalar: + if (a_donatable) { + out.copy_shared_buffer(a); + } else { + out.set_data( + mallocfn(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + } + break; + case BinaryOpType::VectorVector: + if (a_donatable) { + out.copy_shared_buffer(a); + } else if (b_donatable) { + out.copy_shared_buffer(b); + } else { + out.set_data( + mallocfn(a.data_size() * out.itemsize()), + a.data_size(), + a.strides(), + a.flags()); + } + break; + case BinaryOpType::General: + if (a_donatable && 
a.flags().row_contiguous && a.size() == out.size()) { + out.copy_shared_buffer(a); + } else if ( + b_donatable && b.flags().row_contiguous && b.size() == out.size()) { + out.copy_shared_buffer(b); + } else { + out.set_data(mallocfn(out.nbytes())); + } + break; + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/broadcasting.h b/dist/include/mlx/backend/common/broadcasting.h new file mode 100644 index 0000000..29651e9 --- /dev/null +++ b/dist/include/mlx/backend/common/broadcasting.h @@ -0,0 +1,11 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void broadcast(const array& in, array& out); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/buffer_cache.h b/dist/include/mlx/backend/common/buffer_cache.h new file mode 100644 index 0000000..92b20f2 --- /dev/null +++ b/dist/include/mlx/backend/common/buffer_cache.h @@ -0,0 +1,157 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core { + +template +class BufferCache { + public: + BufferCache( + size_t page_size, + std::function get_size, + std::function free) + : page_size_(page_size), + get_size_(std::move(get_size)), + free_(std::move(free)) {} + + ~BufferCache() { + clear(); + } + + BufferCache(const BufferCache&) = delete; + BufferCache& operator=(const BufferCache&) = delete; + + T* reuse_from_cache(size_t size) { + // Find the closest buffer in pool. + auto it = buffer_pool_.lower_bound(size); + if (it == buffer_pool_.end() || + it->first >= std::min(2 * size, size + 2 * page_size_)) { + return nullptr; + } + + // Collect from the cache. + T* buf = it->second->buf; + pool_size_ -= it->first; + + // Remove from record. + remove_from_list(it->second); + buffer_pool_.erase(it); + return buf; + } + + void recycle_to_cache(T* buf) { + assert(buf); + // Add to cache. 
+ BufferHolder* bh = new BufferHolder(buf); + add_at_head(bh); + size_t size = get_size_(buf); + pool_size_ += size; + buffer_pool_.emplace(size, bh); + } + + int release_cached_buffers(size_t min_bytes_to_free) { + if (min_bytes_to_free >= 0.9 * pool_size_) { + return clear(); + } else { + int n_release = 0; + size_t total_bytes_freed = 0; + + while (tail_ && (total_bytes_freed < min_bytes_to_free)) { + // Release buffer. + size_t size = get_size_(tail_->buf); + total_bytes_freed += size; + free_(tail_->buf); + n_release++; + + // Remove from record. + auto its = buffer_pool_.equal_range(size); + auto it = std::find_if(its.first, its.second, [this](const auto& el) { + return el.second == tail_; + }); + assert(it != buffer_pool_.end()); + buffer_pool_.erase(it); + remove_from_list(tail_); + } + + pool_size_ -= total_bytes_freed; + return n_release; + } + } + + int clear() { + int n_release = 0; + for (auto& [size, holder] : buffer_pool_) { + free_(holder->buf); + n_release++; + delete holder; + } + buffer_pool_.clear(); + pool_size_ = 0; + head_ = nullptr; + tail_ = nullptr; + return n_release; + } + + size_t cache_size() const { + return pool_size_; + } + + size_t page_size() const { + return page_size_; + } + + private: + struct BufferHolder { + public: + explicit BufferHolder(T* buf_) : buf(buf_) {} + + BufferHolder* prev{nullptr}; + BufferHolder* next{nullptr}; + T* buf; + }; + + void add_at_head(BufferHolder* to_add) { + if (!head_) { + head_ = to_add; + tail_ = to_add; + } else { + head_->prev = to_add; + to_add->next = head_; + head_ = to_add; + } + } + + void remove_from_list(BufferHolder* to_remove) { + if (to_remove->prev && to_remove->next) { // if middle + to_remove->prev->next = to_remove->next; + to_remove->next->prev = to_remove->prev; + } else if (to_remove->prev && to_remove == tail_) { // if tail + tail_ = to_remove->prev; + tail_->next = nullptr; + } else if (to_remove == head_ && to_remove->next) { // if head + head_ = to_remove->next; + 
head_->prev = nullptr; + } else if (to_remove == head_ && to_remove == tail_) { // if only element + head_ = nullptr; + tail_ = nullptr; + } + + delete to_remove; + } + + std::multimap buffer_pool_; + BufferHolder* head_{nullptr}; + BufferHolder* tail_{nullptr}; + size_t pool_size_{0}; + + const size_t page_size_; + std::function get_size_; + std::function free_; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/compiled.h b/dist/include/mlx/backend/common/compiled.h new file mode 100644 index 0000000..3be3713 --- /dev/null +++ b/dist/include/mlx/backend/common/compiled.h @@ -0,0 +1,77 @@ +// Copyright © 2023-2024 Apple Inc. +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +inline bool is_static_cast(const Primitive& p) { + return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType)); +} + +std::string get_type_string(Dtype d); + +template +void print_float_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + if constexpr (std::is_same_v) { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } else { + os << std::setprecision(std::numeric_limits::digits10 + 1); + } + os << x.item() << std::setprecision(old_precision); +} + +template +void print_int_constant(std::ostream& os, const array& x) { + os << x.item(); +} + +template +void print_complex_constant(std::ostream& os, const array& x) { + auto old_precision = os.precision(); + T constant = x.item(); + + os << get_type_string(x.dtype()) << "(" + << std::setprecision(std::numeric_limits::digits10 + 1) + << constant.real() << ", " << constant.imag() << ")" + << std::setprecision(old_precision); +} + +void print_constant(std::ostream& os, const array& x); + +inline bool is_scalar(const array& x) { + return x.ndim() == 0; +} + +// Check if we can use a contiguous operation given inputs and the output shape +bool compiled_check_contiguity( + const std::vector& inputs, + 
const Shape& shape); + +// Allocate space for the outputs possibly with input donation +void compiled_allocate_outputs( + const std::vector& inputs, + std::vector& outputs, + const std::function& is_constant, + bool contiguous, + const std::function& mallocfn = + allocator::malloc); + +// Collapse contiguous dims ignoring scalars and constants. +std::tuple> compiled_collapse_contiguous_dims( + const std::vector& inputs, + const array& out, + const std::function& is_constant); + +// Return whether the kernel should use large index. +bool compiled_use_large_index( + const std::vector& inputs, + const std::vector& outputs, + bool contiguous); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/copy.h b/dist/include/mlx/backend/common/copy.h new file mode 100644 index 0000000..859ce04 --- /dev/null +++ b/dist/include/mlx/backend/common/copy.h @@ -0,0 +1,50 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum class CopyType { + // Copy a raw scalar input into the full contiguous output + Scalar, + + // Copy the raw input buffer contiguously into a raw output buffer of the same + // size + Vector, + + // Copy the full virtual input to the full contiguous output + General, + + // Copy the full virtual input to the full virtual output. We assume the + // input and output have the same shape. + GeneralGeneral +}; + +inline bool set_copy_output_data( + const array& in, + array& out, + CopyType ctype, + std::function mallocfn = allocator::malloc) { + if (ctype == CopyType::Vector) { + // If the input is donateable, we are doing a vector copy and the types + // have the same size, then the input buffer can hold the output. 
+ if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + return true; + } else { + out.set_data( + mallocfn(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + return false; + } + } else { + out.set_data(mallocfn(out.nbytes())); + return false; + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/hadamard.h b/dist/include/mlx/backend/common/hadamard.h new file mode 100644 index 0000000..ba5c4e4 --- /dev/null +++ b/dist/include/mlx/backend/common/hadamard.h @@ -0,0 +1,109 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/utils.h" + +namespace mlx::core { + +// From http://neilsloane.com/hadamard/ +constexpr std::string_view h12 = R"( ++-++++++++++ +--+-+-+-+-+- ++++-++----++ ++---+--+-++- ++++++-++---- ++-+---+--+-+ +++--+++-++-- ++--++---+--+ +++----+++-++ ++--+-++---+- +++++----+++- ++-+--+-++--- +)"; + +constexpr std::string_view h20 = R"( ++----+----++--++-++- +-+----+---+++---+-++ +--+----+---+++-+-+-+ +---+----+---+++++-+- +----+----++--++-++-+ +-+++++-----+--+++--+ ++-+++-+---+-+--+++-- +++-++--+---+-+--+++- ++++-+---+---+-+--+++ +++++-----++--+-+--++ +--++-+-++-+-----++++ +---++-+-++-+---+-+++ ++---++-+-+--+--++-++ +++---++-+----+-+++-+ +-++---++-+----+++++- +-+--+--++-+----+---- ++-+-----++-+----+--- +-+-+-+---+--+----+-- +--+-+++------+----+- ++--+--++------+----+ +)"; + +constexpr std::string_view h28 = R"( ++------++----++-+--+-+--++-- +-+-----+++-----+-+--+-+--++- +--+-----+++---+-+-+----+--++ +---+-----+++---+-+-+-+--+--+ +----+-----+++---+-+-+++--+-- +-----+-----++++--+-+--++--+- +------++----++-+--+-+--++--+ +--++++-+-------++--+++-+--+- +---++++-+-----+-++--+-+-+--+ ++---+++--+----++-++--+-+-+-- +++---++---+----++-++--+-+-+- ++++---+----+----++-++--+-+-+ +++++--------+-+--++-++--+-+- +-++++--------+++--++--+--+-+ +-+-++-++--++--+--------++++- ++-+-++--+--++--+--------++++ +-+-+-++--+--++--+----+---+++ ++-+-+-++--+--+---+---++---++ 
+++-+-+-++--+------+--+++---+ +-++-+-+-++--+------+-++++--- ++-++-+---++--+------+-++++-- +-++--++-+-++-+++----++------ ++-++--++-+-++-+++-----+----- +++-++---+-+-++-+++-----+---- +-++-++-+-+-+-+--+++-----+--- +--++-++++-+-+----+++-----+-- ++--++-+-++-+-+----+++-----+- +++--++-+-++-+-+----++------+ +)"; + +inline const std::map hadamard_matrices() { + return {{12, h12}, {20, h20}, {28, h28}}; +} + +inline std::pair decompose_hadamard(int n) { + // n = m*2^k + int m = 1; + if (!is_power_of_2(n)) { + auto h_matrices = hadamard_matrices(); + for (auto [factor, _] : h_matrices) { + if (n % factor == 0) { + m = factor; + n /= factor; + break; + } + } + if (m == 1) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where m in (1, 12, 20, 28)."); + } + } + if (n > (1 << 26)) { + throw std::invalid_argument( + "[hadamard] Only supports n = m*2^k where k <= 26"); + } + return {n, m}; +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/matmul.h b/dist/include/mlx/backend/common/matmul.h new file mode 100644 index 0000000..2545c4f --- /dev/null +++ b/dist/include/mlx/backend/common/matmul.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/utils.h" + +#include + +namespace mlx::core { + +inline std::tuple collapse_batches( + const array& a, + const array& b) { + if (a.ndim() == 2) { + return {Shape{1}, Strides{0}, Strides{0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + + auto [batch_shape, batch_strides] = + collapse_contiguous_dims(A_bshape, std::vector{A_bstride, B_bstride}); + + auto a_batch_strides = batch_strides[0]; + auto b_batch_strides = batch_strides[1]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + a_batch_strides.push_back(0); + b_batch_strides.push_back(0); + } + + return std::make_tuple(batch_shape, a_batch_strides, b_batch_strides); +} + +inline std::tuple +collapse_batches(const array& a, const array& b, const array& c) { + if (a.ndim() == 2) { + return {Shape{1}, Strides{0}, Strides{0}, Strides{0}}; + } + + Shape A_bshape{a.shape().begin(), a.shape().end() - 2}; + Strides A_bstride{a.strides().begin(), a.strides().end() - 2}; + Strides B_bstride{b.strides().begin(), b.strides().end() - 2}; + Strides C_bstride{c.strides().begin(), c.strides().end() - 2}; + + auto [batch_shape, batch_strides] = collapse_contiguous_dims( + A_bshape, std::vector{A_bstride, B_bstride, C_bstride}); + + auto A_batch_stride = batch_strides[0]; + auto B_batch_stride = batch_strides[1]; + auto C_batch_stride = batch_strides[2]; + + if (batch_shape.empty()) { + batch_shape.push_back(1); + A_batch_stride.push_back(0); + B_batch_stride.push_back(0); + C_batch_stride.push_back(0); + } + + return std::make_tuple( + batch_shape, A_batch_stride, B_batch_stride, C_batch_stride); +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/reduce.h b/dist/include/mlx/backend/common/reduce.h new file mode 100644 index 0000000..8b24f4f --- /dev/null +++ 
b/dist/include/mlx/backend/common/reduce.h @@ -0,0 +1,59 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +enum ReductionOpType { + // Self-explanatory. Read everything and produce 1 output. + ContiguousAllReduce, + + // The input is contiguous and the last axis is reduced + // N1xR1xN2xR2x...xNnxRn + ContiguousReduce, + + // The input is contiguous and the last axis is not reduced + // R1xN1xR2xN2x...xRnxNn + ContiguousStridedReduce, + + // The input is not contiguous but the last axis is and it is reduced so we + // need to figure out the offsets but we can call the contiguous reduce after + // that. + // N3xR1xN1xR4x...xRn + GeneralContiguousReduce, + + // The input is not contiguous but the last reduction axis and the last axis + // are so we need to figure out the offset but we can call the strided reduce + // after that. + GeneralStridedReduce, + + // The input is not contiguous after the reduction axis and it may contain + // 0-stride axes or transpositions. We could copy the strides and produce a + // transposed outcome or we can read the input out of order and write the + // output in order. 
+ GeneralReduce +}; + +struct ReductionPlan { + ReductionOpType type; + Shape shape; + Strides strides; + + ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_) + : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {} + ReductionPlan(ReductionOpType type_) : type(type_) {} +}; + +ReductionPlan get_reduction_plan(const array& x, const std::vector& axes); + +std::pair shapes_without_reduction_axes( + const array& x, + const std::vector& axes); +std::pair shapes_without_reduction_axes( + Shape shape, + Strides strides, + const std::vector& axes); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/slicing.h b/dist/include/mlx/backend/common/slicing.h new file mode 100644 index 0000000..b667d26 --- /dev/null +++ b/dist/include/mlx/backend/common/slicing.h @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple prepare_slice( + const array& in, + const Shape& start_indices, + const Shape& strides); + +void slice( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/ternary.h b/dist/include/mlx/backend/common/ternary.h new file mode 100644 index 0000000..c63a572 --- /dev/null +++ b/dist/include/mlx/backend/common/ternary.h @@ -0,0 +1,85 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include "mlx/allocator.h" +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +// TODO: Add support for more combinations of input types. 
+enum class TernaryOpType { + ScalarScalarScalar, + VectorVectorVector, + VectorVectorScalar, + VectorScalarVector, + General, +}; + +inline TernaryOpType +get_ternary_op_type(const array& a, const array& b, const array& c) { + TernaryOpType topt; + if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) { + topt = TernaryOpType::ScalarScalarScalar; + } else if ( + (a.flags().row_contiguous && b.flags().row_contiguous && + c.flags().row_contiguous) || + (a.flags().col_contiguous && b.flags().col_contiguous && + c.flags().col_contiguous)) { + topt = TernaryOpType::VectorVectorVector; + } else if ( + b.data_size() == 1 && a.flags().row_contiguous && + c.flags().row_contiguous) { + topt = TernaryOpType::VectorScalarVector; + } else if ( + c.data_size() == 1 && a.flags().row_contiguous && + b.flags().row_contiguous) { + topt = TernaryOpType::VectorVectorScalar; + } else { + topt = TernaryOpType::General; + } + return topt; +} + +inline void set_ternary_op_output_data( + const array& a, + const array& b, + const array& c, + array& out, + TernaryOpType topt, + std::function mallocfn = allocator::malloc) { + auto maybe_donate = [&out](const array& x) { + if (is_donatable(x, out)) { + out.copy_shared_buffer(x); + return true; + } + return false; + }; + + switch (topt) { + case TernaryOpType::ScalarScalarScalar: + out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags()); + break; + case TernaryOpType::VectorVectorVector: + if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) { + out.set_data( + mallocfn(out.itemsize() * b.data_size()), + b.data_size(), + b.strides(), + b.flags()); + } + break; + case TernaryOpType::VectorVectorScalar: + case TernaryOpType::VectorScalarVector: + case TernaryOpType::General: + // Try to donate an input which is row_contiguous + if (!((a.flags().row_contiguous && maybe_donate(a)) || + (b.flags().row_contiguous && maybe_donate(b)) || + (c.flags().row_contiguous && maybe_donate(c)))) { + 
out.set_data(mallocfn(out.nbytes())); + } + break; + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/unary.h b/dist/include/mlx/backend/common/unary.h new file mode 100644 index 0000000..b19fc98 --- /dev/null +++ b/dist/include/mlx/backend/common/unary.h @@ -0,0 +1,29 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +inline void set_unary_output_data( + const array& in, + array& out, + std::function mallocfn = allocator::malloc) { + if (in.flags().contiguous) { + if (is_donatable(in, out)) { + out.copy_shared_buffer(in); + } else { + out.set_data( + mallocfn(in.data_size() * out.itemsize()), + in.data_size(), + in.strides(), + in.flags()); + } + } else { + out.set_data(mallocfn(out.nbytes())); + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/common/utils.h b/dist/include/mlx/backend/common/utils.h new file mode 100644 index 0000000..1b6902f --- /dev/null +++ b/dist/include/mlx/backend/common/utils.h @@ -0,0 +1,205 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/array.h" + +namespace mlx::core { + +// Return the directory that contains current shared library. 
+std::filesystem::path current_binary_dir(); + +inline int64_t +elem_to_loc(int elem, const Shape& shape, const Strides& strides) { + int64_t loc = 0; + for (int i = shape.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(elem, shape[i]); + loc += q_and_r.rem * strides[i]; + elem = q_and_r.quot; + } + return loc; +} + +inline int64_t elem_to_loc(int elem, const array& a) { + if (a.flags().row_contiguous) { + return elem; + } + return elem_to_loc(elem, a.shape(), a.strides()); +} + +inline Strides make_contiguous_strides(const Shape& shape) { + Strides strides(shape.size(), 1); + for (int i = shape.size() - 1; i > 0; i--) { + strides[i - 1] = strides[i] * shape[i]; + } + return strides; +} + +// Collapse dims that are contiguous to possibly route to a better kernel +// e.g. for x = transpose(array({0, 1, 2, 3, 4, 5, 6, 7}, {2, 2, 2}), {2, 0, 1}) +// should return {{2, 4}, {{1, 2}}}. +// +// When multiple arrays are passed they should all have the same shape. The +// collapsed axes are also the same so one shape is returned. +std::tuple> collapse_contiguous_dims( + const Shape& shape, + const std::vector& strides, + int64_t size_cap = std::numeric_limits::max()); + +inline std::tuple> collapse_contiguous_dims( + const std::vector& xs, + size_t size_cap = std::numeric_limits::max()) { + std::vector strides; + for (auto& x : xs) { + strides.emplace_back(x.strides()); + } + return collapse_contiguous_dims(xs[0].shape(), strides, size_cap); +} + +template > +inline auto collapse_contiguous_dims(Arrays&&... xs) { + return collapse_contiguous_dims( + std::vector{std::forward(xs)...}); +} + +// The single array version of the above. +std::pair collapse_contiguous_dims( + const Shape& shape, + const Strides& strides, + int64_t size_cap = std::numeric_limits::max()); +std::pair collapse_contiguous_dims( + const array& a, + int64_t size_cap = std::numeric_limits::max()); + +// Compute the thread block dimensions which fit the given +// input dimensions. 
+// - The thread block dimensions will be powers of two +// - The thread block size will be less than 2^pow2 +using Dims = std::tuple; +Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 = 10); + +// Computes a 2D grid where each element is < UINT_MAX +// Assumes: +// - overall size (product of non-broadcasted dimensions) is < UINT_MAX^2 +// - shape and strides correspond to a contiguous (no holes) but +// possibly broadcasted array +Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides); + +// Same as above but we do an implicit division with divisor. +// Basically, equivalent to factorizing +// Prod(s \forall s in shape if strides[s] > 0) / divisor. +Dims get_2d_grid_dims_common( + const Shape& shape, + const Strides& strides, + size_t divisor); + +// Get both the block and a grid of blocks that covers dim0, dim1 and dim2. +std::pair get_grid_and_block_common(int dim0, int dim1, int dim2); + +struct ContiguousIterator { + inline void step() { + int dims = shape_.size(); + if (dims == 0) { + return; + } + int i = dims - 1; + while (pos_[i] == (shape_[i] - 1) && i > 0) { + pos_[i] = 0; + loc -= (shape_[i] - 1) * strides_[i]; + i--; + } + pos_[i]++; + loc += strides_[i]; + } + + void seek(int64_t n) { + loc = 0; + for (int i = shape_.size() - 1; i >= 0; --i) { + auto q_and_r = ldiv(n, shape_[i]); + loc += q_and_r.rem * strides_[i]; + pos_[i] = q_and_r.rem; + n = q_and_r.quot; + } + } + + void reset() { + loc = 0; + std::fill(pos_.begin(), pos_.end(), 0); + } + + ContiguousIterator() {}; + + explicit ContiguousIterator(const array& a) + : shape_(a.shape()), strides_(a.strides()) { + if (!shape_.empty()) { + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = Shape(shape_.size(), 0); + } + } + + explicit ContiguousIterator( + const Shape& shape, + const Strides& strides, + int dims) + : shape_(shape.begin(), shape.begin() + dims), + strides_(strides.begin(), strides.begin() + dims) { + if (!shape_.empty()) 
{ + std::tie(shape_, strides_) = collapse_contiguous_dims(shape_, strides_); + pos_ = Shape(shape_.size(), 0); + } + } + + int64_t loc{0}; + + private: + Shape shape_; + Strides strides_; + Shape pos_; +}; + +inline auto check_contiguity(const Shape& shape, const Strides& strides) { + size_t no_broadcast_data_size = 1; + int64_t f_stride = 1; + int64_t b_stride = 1; + bool is_row_contiguous = true; + bool is_col_contiguous = true; + + for (int i = 0, ri = shape.size() - 1; ri >= 0; i++, ri--) { + is_col_contiguous &= strides[i] == f_stride || shape[i] == 1; + is_row_contiguous &= strides[ri] == b_stride || shape[ri] == 1; + f_stride *= shape[i]; + b_stride *= shape[ri]; + if (strides[i] > 0) { + no_broadcast_data_size *= shape[i]; + } + } + + return std::make_tuple( + no_broadcast_data_size, is_row_contiguous, is_col_contiguous); +} + +inline bool is_donatable(const array& in, const array& out) { + constexpr size_t donation_extra = 16384; + + return in.is_donatable() && in.itemsize() == out.itemsize() && + in.buffer_size() <= out.nbytes() + donation_extra; +} + +std::pair prepare_reshape(const array& in, const array& out); + +void shared_buffer_reshape( + const array& in, + const Strides& out_strides, + array& out); + +template +inline SmallVector remove_index(SmallVector vec, size_t index) { + vec.erase(std::next(vec.begin(), index)); + return vec; +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/arange.h b/dist/include/mlx/backend/cpu/arange.h new file mode 100644 index 0000000..9e9b03b --- /dev/null +++ b/dist/include/mlx/backend/cpu/arange.h @@ -0,0 +1,28 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cpu/encoder.h" + +namespace mlx::core { + +namespace { + +template +void arange(T start, T next, array& out, size_t size, Stream stream) { + auto ptr = out.data(); + auto step_size = next - start; + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_output_array(out); + encoder.dispatch([ptr, start, step_size, size]() mutable { + for (int i = 0; i < size; ++i) { + ptr[i] = start; + start += step_size; + } + }); +} + +} // namespace + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/available.h b/dist/include/mlx/backend/cpu/available.h new file mode 100644 index 0000000..1df95de --- /dev/null +++ b/dist/include/mlx/backend/cpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cpu { + +bool is_available(); + +} // namespace mlx::core::cpu diff --git a/dist/include/mlx/backend/cpu/binary.h b/dist/include/mlx/backend/cpu/binary.h new file mode 100644 index 0000000..acaca50 --- /dev/null +++ b/dist/include/mlx/backend/cpu/binary.h @@ -0,0 +1,517 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once +#include + +#include "mlx/array.h" +#include "mlx/backend/common/binary.h" +#include "mlx/backend/common/utils.h" + +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +template +struct VectorScalar { + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *b; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::load(a), simd::Simd(scalar))); + dst += N; + a += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(*a, scalar); + dst++; + a++; + } + } +}; + +template +struct ScalarVector { + template + void operator()(const T* a, const T* b, U* dst, int size) { + T scalar = *a; + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::Simd(scalar), simd::load(b))); + dst += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(scalar, *b); + dst++; + b++; + } + } +}; + +template +struct VectorVector { + template + void operator()(const T* a, const T* b, U* dst, int size) { + constexpr int N = simd::max_size; + while (size >= N) { + simd::store(dst, Op{}(simd::load(a), simd::load(b))); + dst += N; + a += N; + b += N; + size -= N; + } + while (size-- > 0) { + *dst = Op{}(*a, *b); + dst++; + a++; + b++; + } + } +}; + +template +void binary_op_dims( + const T* a, + const T* b, + U* out, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, axis + 1); + } else { + if constexpr (Strided) { + Op{}(a, b, out, stride_out); + } else { + *out = Op{}(*a, *b); + } + } + out += stride_out; + a += stride_a; + b += stride_b; + } +} + +template +void binary_op_dispatch_dims( + 
const T* a, + const T* b, + U* out, + int dim, + int size, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides) { + switch (dim) { + case 1: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + case 2: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + case 3: + binary_op_dims( + a, b, out, shape, a_strides, b_strides, out_strides, 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, dim - 3); + ContiguousIterator b_it(shape, b_strides, dim - 3); + auto stride = out_strides[dim - 4]; + for (int64_t elem = 0; elem < size; elem += stride) { + binary_op_dims( + a + a_it.loc, + b + b_it.loc, + out + elem, + shape, + a_strides, + b_strides, + out_strides, + dim - 3); + a_it.step(); + b_it.step(); + } +} + +template +void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { + // The full computation is scalar scalar so call the base op once + auto a_ptr = a.data(); + auto b_ptr = b.data(); + + auto out_ptr = out.data(); + if (bopt == BinaryOpType::ScalarScalar) { + *out_ptr = Op{}(*a_ptr, *b_ptr); + return; + } + + // The full computation is scalar vector so delegate to the op + if (bopt == BinaryOpType::ScalarVector) { + ScalarVector{}(a_ptr, b_ptr, out_ptr, b.data_size()); + return; + } + + // The full computation is vector scalar so delegate to the op + if (bopt == BinaryOpType::VectorScalar) { + VectorScalar{}(a_ptr, b_ptr, out_ptr, a.data_size()); + return; + } + + // The full computation is vector vector so delegate to the op + if (bopt == BinaryOpType::VectorVector) { + VectorVector{}(a_ptr, b_ptr, out_ptr, a.size()); + return; + } + + // General computation so let's try to optimize + auto [new_shape, new_strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out.strides()}); + auto& a_strides = new_strides[0]; + auto& b_strides = new_strides[1]; + auto& strides = new_strides[2]; + + 
// Get the left-most dim such that the array is row contiguous after + auto leftmost_rc_dim = [&strides](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == strides[d]; d--) { + } + return d + 1; + }; + auto a_rc_dim = leftmost_rc_dim(a_strides); + auto b_rc_dim = leftmost_rc_dim(b_strides); + + // Get the left-most dim such that the array is a broadcasted "scalar" after + auto leftmost_s_dim = [](const auto& arr_strides) { + int d = arr_strides.size() - 1; + for (; d >= 0 && arr_strides[d] == 0; d--) { + } + return d + 1; + }; + auto a_s_dim = leftmost_s_dim(a_strides); + auto b_s_dim = leftmost_s_dim(b_strides); + + auto ndim = new_shape.size(); + + // Case 1: LxM and FxM where L and F are broadcastable and M is row + // contiguous + int dim = ndim; + if (int d = std::max(a_rc_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::VectorVector; + dim = d; + // Case 2: LxM and Fx1 where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_rc_dim, b_s_dim); d < ndim) { + bopt = BinaryOpType::VectorScalar; + dim = d; + // Case 3: Lx1 and FxM where L and F are broadcastable and M is row + // contiguous + } else if (int d = std::max(a_s_dim, b_rc_dim); d < ndim) { + bopt = BinaryOpType::ScalarVector; + dim = d; + } + + // Can be sure dim > 0 since otherwise we would have used one of the fully + // contiguous methods above. Except for the case that the flags do not + // correspond to the underlying contiguity. 
+ if (dim == 0 || strides[dim - 1] < 16) { + bopt = BinaryOpType::General; + dim = ndim; + } + + switch (bopt) { + case BinaryOpType::VectorVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::VectorScalar: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + case BinaryOpType::ScalarVector: + binary_op_dispatch_dims>( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + default: + binary_op_dispatch_dims( + a_ptr, + b_ptr, + out_ptr, + dim, + a.size(), + new_shape, + a_strides, + b_strides, + strides); + break; + } +} + +template +void binary_op(const array& a, const array& b, array& out, BinaryOpType bopt) { + binary_op(a, b, out, bopt); +} + +template +void binary_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + 
break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void comparison_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (a.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + break; + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + case float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + } + }); +} + +template +void binary_float_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case 
float16: + binary_op(a, b, out, bopt); + break; + case float32: + binary_op(a, b, out, bopt); + break; + case float64: + binary_op(a, b, out, bopt); + break; + case bfloat16: + binary_op(a, b, out, bopt); + break; + case complex64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error( + "[binary_float] Only supports floating point types."); + } + }); +} + +template +void binary_int_op_cpu( + const array& a, + const array& b, + array& out, + Op op, + Stream stream) { + auto bopt = get_binary_op_type(a, b); + set_binary_op_output_data(a, b, out, bopt); + + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_input_array(b); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + b = array::unsafe_weak_copy(b), + out = array::unsafe_weak_copy(out), + bopt]() mutable { + switch (out.dtype()) { + case bool_: + binary_op(a, b, out, bopt); + case uint8: + binary_op(a, b, out, bopt); + break; + case uint16: + binary_op(a, b, out, bopt); + break; + case uint32: + binary_op(a, b, out, bopt); + break; + case uint64: + binary_op(a, b, out, bopt); + break; + case int8: + binary_op(a, b, out, bopt); + break; + case int16: + binary_op(a, b, out, bopt); + break; + case int32: + binary_op(a, b, out, bopt); + break; + case int64: + binary_op(a, b, out, bopt); + break; + default: + throw std::runtime_error("[binary_int] Type not supported"); + break; + } + }); +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/binary_ops.h b/dist/include/mlx/backend/cpu/binary_ops.h new file mode 100644 index 0000000..d50751c --- /dev/null +++ b/dist/include/mlx/backend/cpu/binary_ops.h @@ -0,0 +1,98 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core::detail { + +using namespace mlx::core::simd; + +#define BINARY_SINGLE() \ + template \ + T operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } + +#define DEFAULT_BINARY_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + BINARY_SINGLE() \ + }; + +DEFAULT_BINARY_OP(Add, operator+) +DEFAULT_BINARY_OP(ArcTan2, atan2) +DEFAULT_BINARY_OP(Divide, operator/) +DEFAULT_BINARY_OP(Multiply, operator*) +DEFAULT_BINARY_OP(Subtract, operator-) +DEFAULT_BINARY_OP(LogicalAnd, operator&&) +DEFAULT_BINARY_OP(LogicalOr, operator||) +DEFAULT_BINARY_OP(BitwiseAnd, operator&) +DEFAULT_BINARY_OP(BitwiseOr, operator|) +DEFAULT_BINARY_OP(BitwiseXor, operator^) +DEFAULT_BINARY_OP(LeftShift, operator<<) +DEFAULT_BINARY_OP(RightShift, operator>>) +DEFAULT_BINARY_OP(Remainder, remainder) +DEFAULT_BINARY_OP(Maximum, maximum) +DEFAULT_BINARY_OP(Minimum, minimum) +DEFAULT_BINARY_OP(Power, pow) + +#define DEFAULT_BOOL_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x, Simd y) { \ + return op(x, y); \ + } \ + template \ + bool operator()(T x, T y) { \ + return (*this)(Simd(x), Simd(y)).value; \ + } \ + }; + +DEFAULT_BOOL_OP(Equal, operator==) +DEFAULT_BOOL_OP(Greater, operator>) +DEFAULT_BOOL_OP(GreaterEqual, operator>=) +DEFAULT_BOOL_OP(Less, operator<) +DEFAULT_BOOL_OP(LessEqual, operator<=) +DEFAULT_BOOL_OP(NotEqual, operator!=) + +struct NaNEqual { + template + Simd operator()(Simd x, Simd y) { + return x == y || (isnan(x) && isnan(y)); + } + template + bool operator()(T x, T y) { + return (*this)(Simd(x), Simd(y)).value; + } +}; + +struct LogAddExp { + template + Simd operator()(Simd x, Simd y) { + auto maxval = maximum(x, y); + auto minval = minimum(x, y); + auto mask = minval == -inf || maxval == inf; + auto out = maxval + log1p(exp(minval - maxval)); + return select(mask, Simd(maxval), Simd(out)); + } + 
BINARY_SINGLE() +}; + +struct Select { + template + T operator()(bool condition, T x, T y) { + return (*this)(Simd(condition), Simd(x), Simd(y)) + .value; + } + + template + Simd operator()(Simd condition, Simd x, Simd y) { + return select(condition, x, y); + } +}; + +} // namespace mlx::core::detail diff --git a/dist/include/mlx/backend/cpu/binary_two.h b/dist/include/mlx/backend/cpu/binary_two.h new file mode 100644 index 0000000..fa0ca79 --- /dev/null +++ b/dist/include/mlx/backend/cpu/binary_two.h @@ -0,0 +1,166 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/binary.h" + +namespace mlx::core { + +namespace { + +template +void binary_op_dims( + const T* a, + const T* b, + U* out_a, + U* out_b, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + binary_op_dims( + a, + b, + out_a, + out_b, + op, + shape, + a_strides, + b_strides, + out_strides, + axis + 1); + } else { + std::tie(*out_a, *out_b) = op(*a, *b); + } + a += stride_a; + b += stride_b; + out_a += stride_out; + out_b += stride_out; + } +} + +template +void binary_op_dispatch_dims( + const array& a, + const array& b, + array& out_a, + array& out_b, + Op op) { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), out_a.strides()}); + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + U* out_a_ptr = out_a.data(); + U* out_b_ptr = out_b.data(); + + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& out_strides = strides[2]; + int ndim = shape.size(); + switch (ndim) { + case 1: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + 
out_strides, + 0); + return; + case 2: + binary_op_dims( + a_ptr, + b_ptr, + out_a_ptr, + out_b_ptr, + op, + shape, + a_strides, + b_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < a.size(); elem += stride) { + binary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + out_a_ptr + elem, + out_b_ptr + elem, + op, + shape, + a_strides, + b_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + } +} + +template +void binary_op( + const array& a, + const array& b, + array& out_a, + array& out_b, + Op op, + BinaryOpType bopt) { + // The full computation is scalar scalar so call the base op once + if (bopt == BinaryOpType::General) { + binary_op_dispatch_dims(a, b, out_a, out_b, op); + return; + } + + auto a_ptr = a.data(); + auto b_ptr = b.data(); + auto out_a_ptr = out_a.data(); + auto out_b_ptr = out_b.data(); + if (bopt == BinaryOpType::ScalarScalar) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + } else if (bopt == BinaryOpType::ScalarVector) { + for (size_t i = 0; i < b.data_size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + b_ptr++; + } + } else if (bopt == BinaryOpType::VectorScalar) { + for (size_t i = 0; i < a.data_size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + } + } else { // VectorVector + for (size_t i = 0; i < a.size(); ++i) { + std::tie(*out_a_ptr, *out_b_ptr) = op(*a_ptr, *b_ptr); + out_a_ptr++; + out_b_ptr++; + a_ptr++; + b_ptr++; + } + } +} + +} // namespace + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/compiled_preamble.h b/dist/include/mlx/backend/cpu/compiled_preamble.h new file mode 100644 index 0000000..31ca1b4 --- /dev/null +++ b/dist/include/mlx/backend/cpu/compiled_preamble.h @@ -0,0 +1,12 @@ +// 
Copyright © 2023-24 Apple Inc. + +#pragma once + +// clang-format off +#include "mlx/types/half_types.h" +#include "mlx/types/complex.h" +#include "mlx/backend/cpu/unary_ops.h" +#include "mlx/backend/cpu/binary_ops.h" +// clang-format on + +const char* get_kernel_preamble(); diff --git a/dist/include/mlx/backend/cpu/copy.h b/dist/include/mlx/backend/cpu/copy.h new file mode 100644 index 0000000..0072913 --- /dev/null +++ b/dist/include/mlx/backend/cpu/copy.h @@ -0,0 +1,36 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/common/copy.h" +#include "mlx/backend/common/utils.h" + +namespace mlx::core { + +void copy_cpu(const array& src, array& dst, CopyType ctype, Stream stream); +void copy_cpu_inplace( + const array& src, + array& dst, + CopyType ctype, + Stream stream); + +void copy_cpu_inplace( + const array& src, + array& dst, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype, + Stream stream, + const std::optional& dynamic_i_offset = std::nullopt, + const std::optional& dynamic_o_offset = std::nullopt); + +// Return a contiguous array with same shape that copies the data of |arr|. +array contiguous_copy_cpu(const array& arr, Stream stream); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/encoder.h b/dist/include/mlx/backend/cpu/encoder.h new file mode 100644 index 0000000..b8e33ca --- /dev/null +++ b/dist/include/mlx/backend/cpu/encoder.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/scheduler.h" + +namespace mlx::core::cpu { + +// Number of dispatches per scheduler task +constexpr int DISPATCHES_PER_TASK = 10; + +struct CommandEncoder { + CommandEncoder(Stream stream) : stream_(stream) {} + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + CommandEncoder(CommandEncoder&&) = delete; + CommandEncoder& operator=(CommandEncoder&&) = delete; + + void set_input_array(const array& a) {} + void set_output_array(array& a) {} + + // Hold onto a temporary until any already scheduled tasks which use it as + // an input are complete. + void add_temporary(array arr) { + temporaries_.push_back(std::move(arr)); + } + + void add_temporaries(std::vector arrays) { + temporaries_.insert( + temporaries_.end(), + std::make_move_iterator(arrays.begin()), + std::make_move_iterator(arrays.end())); + } + + std::vector& temporaries() { + return temporaries_; + } + + template + void dispatch(F&& f, Args&&... args) { + num_ops_ = (num_ops_ + 1) % DISPATCHES_PER_TASK; + auto task = std::bind(std::forward(f), std::forward(args)...); + if (num_ops_ == 0) { + scheduler::notify_new_task(stream_); + auto task_wrap = [s = stream_, task = std::move(task)]() mutable { + task(); + scheduler::notify_task_completion(s); + }; + scheduler::enqueue(stream_, std::move(task_wrap)); + } else { + scheduler::enqueue(stream_, std::move(task)); + } + } + + private: + Stream stream_; + std::vector temporaries_; + int num_ops_{0}; +}; + +CommandEncoder& get_command_encoder(Stream stream); + +} // namespace mlx::core::cpu diff --git a/dist/include/mlx/backend/cpu/eval.h b/dist/include/mlx/backend/cpu/eval.h new file mode 100644 index 0000000..20156d6 --- /dev/null +++ b/dist/include/mlx/backend/cpu/eval.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::cpu { + +void eval(array& arr); + +} // namespace mlx::core::cpu diff --git a/dist/include/mlx/backend/cpu/gemm.h b/dist/include/mlx/backend/cpu/gemm.h new file mode 100644 index 0000000..d665cb9 --- /dev/null +++ b/dist/include/mlx/backend/cpu/gemm.h @@ -0,0 +1,26 @@ +// Copyright © 2025 Apple Inc. + +#pragma once +#include "mlx/array.h" + +namespace mlx::core { + +template +void matmul( + const T* a, + const T* b, + T* out, + bool a_transposed, + bool b_transposed, + size_t lda, + size_t ldb, + size_t ldc, + float alpha, + float beta, + size_t batch_size, + const Shape& a_shape, + const Strides& a_strides, + const Shape& b_shape, + const Strides& b_strides); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/gemms/simd_gemm.h b/dist/include/mlx/backend/cpu/gemms/simd_gemm.h new file mode 100644 index 0000000..a23c7de --- /dev/null +++ b/dist/include/mlx/backend/cpu/gemms/simd_gemm.h @@ -0,0 +1,139 @@ +// Copyright © 2025 Apple Inc. 
+#pragma once + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core { + +inline int ceildiv(int a, int b) { + return (a + b - 1) / b; +} + +template +void load_block( + const T* in, + AccT* out, + int M, + int N, + int i, + int j, + bool transpose) { + if (transpose) { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[jj * block_size + ii] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } else { + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + out[ii * block_size + jj] = + in[(i * block_size + ii) * N + j * block_size + jj]; + } + } + } +} + +template +void simd_gemm( + const T* a, + const T* b, + T* c, + bool a_trans, + bool b_trans, + int M, + int N, + int K, + float alpha, + float beta) { + constexpr int block_size = 16; + constexpr int simd_size = simd::max_size; + static_assert( + (block_size % simd_size) == 0, + "Block size must be divisible by SIMD size"); + + int last_k_block_size = K - block_size * (K / block_size); + int last_k_simd_block = (last_k_block_size / simd_size) * simd_size; + for (int i = 0; i < ceildiv(M, block_size); i++) { + for (int j = 0; j < ceildiv(N, block_size); j++) { + AccT c_block[block_size * block_size] = {0.0}; + AccT a_block[block_size * block_size]; + AccT b_block[block_size * block_size]; + + int k = 0; + for (; k < K / block_size; k++) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + for (int kk = 0; kk < block_size; kk 
+= simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + } + } + } + if (last_k_block_size) { + // Load a and b blocks + if (a_trans) { + load_block(a, a_block, K, M, k, i, true); + } else { + load_block(a, a_block, M, K, i, k, false); + } + if (b_trans) { + load_block(b, b_block, N, K, j, k, false); + } else { + load_block(b, b_block, K, N, k, j, true); + } + + // Multiply and accumulate + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + int kk = 0; + for (; kk < last_k_simd_block; kk += simd_size) { + auto av = + simd::load(a_block + ii * block_size + kk); + auto bv = + simd::load(b_block + jj * block_size + kk); + c_block[ii * block_size + jj] += simd::sum(av * bv); + } + for (; kk < last_k_block_size; ++kk) { + c_block[ii * block_size + jj] += + a_block[ii * block_size + kk] * b_block[jj * block_size + kk]; + } + } + } + } + + // Store + for (int ii = 0; ii < block_size && i * block_size + ii < M; ++ii) { + for (int jj = 0; jj < block_size && j * block_size + jj < N; ++jj) { + auto c_idx = (i * block_size + ii) * N + j * block_size + jj; + if (beta != 0) { + c[c_idx] = static_cast( + alpha * c_block[ii * block_size + jj] + beta * c[c_idx]); + } else { + c[c_idx] = static_cast(alpha * c_block[ii * block_size + jj]); + } + } + } + } + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/jit_compiler.h b/dist/include/mlx/backend/cpu/jit_compiler.h new file mode 100644 index 0000000..3a9e988 --- /dev/null +++ b/dist/include/mlx/backend/cpu/jit_compiler.h @@ -0,0 +1,20 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include + +namespace mlx::core { + +class JitCompiler { + public: + // Build a shell command that compiles a source code file to a shared library. 
+ static std::string build_command( + const std::filesystem::path& dir, + const std::string& source_file_name, + const std::string& shared_lib_name); + + // Run a command and get its output. + static std::string exec(const std::string& cmd); +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/lapack.h b/dist/include/mlx/backend/cpu/lapack.h new file mode 100644 index 0000000..1c3ba1a --- /dev/null +++ b/dist/include/mlx/backend/cpu/lapack.h @@ -0,0 +1,80 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#define LAPACK_COMPLEX_CUSTOM +#define lapack_complex_float std::complex +#define lapack_complex_double std::complex +#define lapack_complex_float_real(z) ((z).real()) +#define lapack_complex_float_imag(z) ((z).imag()) +#define lapack_complex_double_real(z) ((z).real()) +#define lapack_complex_double_imag(z) ((z).imag()) + +#ifdef MLX_USE_ACCELERATE +#include +#else +#include +#include +#endif + +#if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME) + +// This is to work around a change in the function signatures of lapack >= 3.9.1 +// where functions taking char* also include a strlen argument, see a similar +// change in OpenCV: +// https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57 +#define MLX_LAPACK_FUNC(f) LAPACK_##f + +#else + +#define MLX_LAPACK_FUNC(f) f##_ + +#endif + +#define INSTANTIATE_LAPACK_REAL(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_REAL(geqrf) +INSTANTIATE_LAPACK_REAL(orgqr) +INSTANTIATE_LAPACK_REAL(syevd) +INSTANTIATE_LAPACK_REAL(potrf) +INSTANTIATE_LAPACK_REAL(getrf) +INSTANTIATE_LAPACK_REAL(getri) +INSTANTIATE_LAPACK_REAL(trtri) + +#define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ + template \ + void FUNC(Args... 
args) { \ + if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_COMPLEX(heevd) + +#define INSTANTIATE_LAPACK_ALL(FUNC) \ + template \ + void FUNC(Args... args) { \ + if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v) { \ + MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ + } else if constexpr (std::is_same_v>) { \ + MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ + } \ + } + +INSTANTIATE_LAPACK_ALL(geev) +INSTANTIATE_LAPACK_ALL(gesdd) diff --git a/dist/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h b/dist/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h new file mode 100644 index 0000000..9505448 --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/accelerate_fp16_simd.h @@ -0,0 +1,56 @@ +#pragma once + +#include "mlx/backend/cpu/simd/base_simd.h" + +#if MLX_SIMD_LIBRARY_VERSION < 6 +#include "mlx/backend/cpu/simd/neon_fp16_simd.h" +#endif + +namespace mlx::core::simd { + +#if MLX_SIMD_LIBRARY_VERSION >= 6 +constexpr int N = 8; +template +struct ScalarT { + using v = _Float16; +}; +#endif + +template <> +inline constexpr int max_size = N; + +#define SIMD_FP16_DEFAULT_UNARY(op) \ + template <> \ + inline Simd op(Simd v) { \ + Simd in = v; \ + return op(in); \ + } + +SIMD_FP16_DEFAULT_UNARY(acos) +SIMD_FP16_DEFAULT_UNARY(acosh) +SIMD_FP16_DEFAULT_UNARY(asin) +SIMD_FP16_DEFAULT_UNARY(asinh) +SIMD_FP16_DEFAULT_UNARY(atan) +SIMD_FP16_DEFAULT_UNARY(atanh) +SIMD_FP16_DEFAULT_UNARY(cosh) +SIMD_FP16_DEFAULT_UNARY(expm1) +SIMD_FP16_DEFAULT_UNARY(log) +SIMD_FP16_DEFAULT_UNARY(log2) +SIMD_FP16_DEFAULT_UNARY(log10) +SIMD_FP16_DEFAULT_UNARY(log1p) +SIMD_FP16_DEFAULT_UNARY(sinh) +SIMD_FP16_DEFAULT_UNARY(tan) 
+SIMD_FP16_DEFAULT_UNARY(tanh) + +#define SIMD_FP16_DEFAULT_BINARY(op) \ + template <> \ + inline Simd op(Simd x, Simd y) { \ + Simd a = x; \ + Simd b = y; \ + return op(a, b); \ + } +SIMD_FP16_DEFAULT_BINARY(atan2) +SIMD_FP16_DEFAULT_BINARY(remainder) +SIMD_FP16_DEFAULT_BINARY(pow) + +} // namespace mlx::core::simd diff --git a/dist/include/mlx/backend/cpu/simd/accelerate_simd.h b/dist/include/mlx/backend/cpu/simd/accelerate_simd.h new file mode 100644 index 0000000..f62c67d --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/accelerate_simd.h @@ -0,0 +1,329 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "mlx/backend/cpu/simd/base_simd.h" + +// There seems to be a bug in simd/base_simd.h +// __XROS_2_0 is not defined, the expression evaluates +// to true instead of false setting the SIMD library +// higher than it should be even on macOS < 15 +#if __MAC_OS_X_VERSION_MIN_REQUIRED >= 150000 || \ + __IPHONE_OS_VERSION_MIN_REQUIRED >= 180000 || \ + __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __WATCH_OS_VERSION_MIN_REQUIRED >= 110000 || \ + __TV_OS_VERSION_MIN_REQUIRED >= 180000 +#define MLX_SIMD_LIBRARY_VERSION 6 +#else +#define MLX_SIMD_LIBRARY_VERSION 5 +#endif + +namespace mlx::core::simd { + +// Apple simd namespace +namespace asd = ::simd; + +// This indirection is needed to remap certain types to ones that accelerate +// SIMD can handle +template +struct ScalarT { + using v = T; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = char; +}; +template +struct ScalarT { + using v = unsigned long; +}; +template +struct ScalarT { + using v = long; +}; + +template +struct Simd { + static constexpr int size = N; + using scalar_t = typename ScalarT::v; + + Simd() {} + + template + Simd(Simd other) : value(asd::convert(other.value)) {} + + template + Simd(U v) : value(v){}; + + Simd(Simd x, Simd y) { + value = asd::make::packed_t>( + x.value, y.value); + }; + + T operator[](int 
idx) const { + return reinterpret_cast(&value)[idx]; + } + + T& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + typename asd::Vector::packed_t value; +}; + +// Values chosen based on benchmarks on M3 Max +// TODO: consider choosing these more optimally +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 16; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; +template <> +inline constexpr int max_size = 8; +template <> +inline constexpr int max_size = 4; + +#define SIMD_DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd v) { \ + return op(v.value); \ + } + +SIMD_DEFAULT_UNARY(abs, asd::abs) +SIMD_DEFAULT_UNARY(floor, asd::floor) +SIMD_DEFAULT_UNARY(acos, asd::acos) +SIMD_DEFAULT_UNARY(acosh, asd::acosh) +SIMD_DEFAULT_UNARY(asin, asd::asin) +SIMD_DEFAULT_UNARY(asinh, asd::asinh) +SIMD_DEFAULT_UNARY(atan, asd::atan) +SIMD_DEFAULT_UNARY(atanh, asd::atanh) +SIMD_DEFAULT_UNARY(ceil, asd::ceil) +SIMD_DEFAULT_UNARY(cosh, asd::cosh) +SIMD_DEFAULT_UNARY(expm1, asd::expm1) +SIMD_DEFAULT_UNARY(log, asd::log) +SIMD_DEFAULT_UNARY(log2, asd::log2) +SIMD_DEFAULT_UNARY(log10, asd::log10) +SIMD_DEFAULT_UNARY(log1p, asd::log1p) +SIMD_DEFAULT_UNARY(rint, asd::rint) +SIMD_DEFAULT_UNARY(sinh, asd::sinh) +SIMD_DEFAULT_UNARY(sqrt, asd::sqrt) +SIMD_DEFAULT_UNARY(rsqrt, asd::rsqrt) +SIMD_DEFAULT_UNARY(recip, asd::recip) +SIMD_DEFAULT_UNARY(tan, asd::tan) +SIMD_DEFAULT_UNARY(tanh, asd::tanh) + +template +Simd operator-(Simd v) { + return -v.value; +} + +template +Simd operator~(Simd v) { + return ~v.value; +} + +template +Simd isnan(Simd v) { + return asd::convert(v.value != v.value); +} + +// No simd_boolN in accelerate, use int8_t instead +template +Simd operator!(Simd v) { + 
return asd::convert(!v.value); +} + +#define SIMD_DEFAULT_BINARY(OP) \ + template \ + Simd operator OP(Simd x, U y) { \ + return asd::convert::scalar_t>(x.value OP y); \ + } \ + template \ + Simd operator OP(T1 x, Simd y) { \ + return asd::convert::scalar_t>(x OP y.value); \ + } \ + template \ + Simd operator OP(Simd x, Simd y) { \ + return asd::convert::scalar_t>(x.value OP y.value); \ + } + +SIMD_DEFAULT_BINARY(+) +SIMD_DEFAULT_BINARY(-) +SIMD_DEFAULT_BINARY(/) +SIMD_DEFAULT_BINARY(*) +SIMD_DEFAULT_BINARY(<<) +SIMD_DEFAULT_BINARY(>>) +SIMD_DEFAULT_BINARY(|) +SIMD_DEFAULT_BINARY(^) +SIMD_DEFAULT_BINARY(&) +SIMD_DEFAULT_BINARY(&&) +SIMD_DEFAULT_BINARY(||) + +#define SIMD_DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, U b) { \ + return asd::convert(a.value OP b); \ + } \ + template \ + Simd operator OP(T a, Simd b) { \ + return asd::convert(a OP b.value); \ + } \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return asd::convert(a.value OP b.value); \ + } + +SIMD_DEFAULT_COMPARISONS(>) +SIMD_DEFAULT_COMPARISONS(<) +SIMD_DEFAULT_COMPARISONS(>=) +SIMD_DEFAULT_COMPARISONS(<=) +SIMD_DEFAULT_COMPARISONS(==) +SIMD_DEFAULT_COMPARISONS(!=) + +template +Simd clz(Simd x) { + auto a = *(uint32x4_t*)(&x); + auto b = *((uint32x4_t*)(&x) + 1); + a = vclzq_u32(a); + b = vclzq_u32(b); + return asd::make_uint8(a, b); +} + +template +Simd atan2(Simd a, Simd b) { + return asd::atan2(a.value, b.value); +} + +template +Simd maximum(Simd a, Simd b) { + auto out = Simd(asd::max(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; +} + +template +Simd minimum(Simd a, Simd b) { + auto out = Simd(asd::min(a.value, b.value)); + if constexpr (!std::is_integral_v) { + out = select(isnan(b), b, select(isnan(a), a, out)); + } + return out; +} + +template +Simd remainder(Simd a, Simd b) { + Simd r; + if constexpr (!std::is_integral_v) { + r = asd::remainder(a.value, b.value); + } else { + r = a 
- b * (a / b); + } + if constexpr (std::is_signed_v) { + auto mask = r != 0 && (r < 0 != b < 0); + r = select(mask, r + b, r); + } + return r; +} + +template +Simd select(Simd mask, Simd x, Simd y) { + static_assert(std::is_same_v); + if constexpr (sizeof(T1) == 1) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 2) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else if constexpr (sizeof(T1) == 4) { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } else { + return asd::bitselect(y.value, x.value, asd::convert(mask.value)); + } +} + +template +Simd pow(Simd base, Simd exp) { + if constexpr (!std::is_integral_v) { + return asd::pow(base.value, exp.value); + } else { + Simd res = 1; + // Raising an integer to a negative power is undefined + if (any(exp < 0)) { + return 0; + } + while (any(exp > 0)) { + res = select((exp & 1) != 0, res * base, res); + base = select(exp > 0, base * base, base); + exp = exp >> 1; + } + return res; + } +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return asd::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return asd::muladd(x.value, y.value, Simd(z).value); +} + +// Reductions + +template +bool all(Simd x) { + return asd::all(x.value); +} +template +bool any(Simd x) { + return asd::any(x.value); +} +template +T sum(Simd x) { + return asd::reduce_add(x.value); +} +template +T max(Simd x) { + return asd::reduce_max(x.value); +} +template +T min(Simd x) { + return asd::reduce_min(x.value); +} + +template +T prod(Simd x) { + auto ptr = (T*)&x; + auto lhs = load(ptr); + auto rhs = load(ptr + N / 2); + return prod(lhs * rhs); +} + +} // namespace mlx::core::simd + +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#include "mlx/backend/cpu/simd/accelerate_fp16_simd.h" +#endif diff --git a/dist/include/mlx/backend/cpu/simd/base_simd.h b/dist/include/mlx/backend/cpu/simd/base_simd.h new file 
mode 100644 index 0000000..fc9fbbf --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/base_simd.h @@ -0,0 +1,295 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace mlx::core::simd { +template +struct Simd; + +template +static constexpr int max_size = 1; + +template +struct Simd { + static constexpr int size = 1; + T value; + Simd() {} + template + Simd(Simd v) : value(v.value) {} + template + Simd(U v) : value(v) {} +}; + +template +Simd load(const T* x) { + return *(Simd*)x; +} + +template +void store(T* dst, Simd x) { + // Maintain invariant that bool is either 0 or 1 as + // simd comparison ops set all bits in the result to 1 + if constexpr (std::is_same_v && N > 1) { + x = x & 1; + } + *(Simd*)dst = x; +} + +template +constexpr bool is_complex = false; + +template +constexpr bool is_complex().real())>> = + true; + +template +Simd rint(Simd in) { + if constexpr (is_complex) { + return Simd{ + T{std::rint(in.value.real()), std::rint(in.value.imag())}}; + } else { + return Simd{std::rint(in.value)}; + } +} + +template +Simd rsqrt(Simd in) { + return T(1.0) / sqrt(in); +} + +template +Simd recip(Simd in) { + return T(1.0) / in; +} + +#define DEFAULT_UNARY(name, op) \ + template \ + Simd name(Simd in) { \ + return op(in.value); \ + } + +DEFAULT_UNARY(operator-, std::negate{}) +DEFAULT_UNARY(operator!, std::logical_not{}) +DEFAULT_UNARY(abs, std::abs) +DEFAULT_UNARY(acos, std::acos) +DEFAULT_UNARY(acosh, std::acosh) +DEFAULT_UNARY(asin, std::asin) +DEFAULT_UNARY(asinh, std::asinh) +DEFAULT_UNARY(atan, std::atan) +DEFAULT_UNARY(atanh, std::atanh) +DEFAULT_UNARY(ceil, std::ceil) +DEFAULT_UNARY(conj, std::conj) +DEFAULT_UNARY(cosh, std::cosh) +DEFAULT_UNARY(expm1, std::expm1) +DEFAULT_UNARY(floor, std::floor) +DEFAULT_UNARY(log, std::log) +DEFAULT_UNARY(log10, std::log10) +DEFAULT_UNARY(sinh, std::sinh) +DEFAULT_UNARY(sqrt, std::sqrt) +DEFAULT_UNARY(tan, std::tan) +DEFAULT_UNARY(tanh, std::tanh) + +template +Simd log1p(Simd in) { + if 
constexpr (is_complex) { + auto x = in.value.real(); + auto y = in.value.imag(); + auto zabs = std::abs(in.value); + auto theta = std::atan2(y, x + 1); + if (zabs < 0.5) { + auto r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return Simd{T{x, theta}}; + } + return Simd{T{((typeof(x))(0.5)) * std::log1p(r), theta}}; + } else { + auto z0 = std::hypot(x + 1, y); + return Simd{T{std::log(z0), theta}}; + } + } else { + return Simd{std::log1p(in.value)}; + } +} + +template +Simd log2(Simd in) { + if constexpr (is_complex) { + auto out = std::log(in.value); + auto scale = decltype(out.real())(M_LN2); + return Simd{T{out.real() / scale, out.imag() / scale}}; + } else { + return Simd{std::log2(in.value)}; + } +} + +template +Simd operator~(Simd in) { + return ~in.value; +} + +template +auto real(Simd in) -> Simd { + return std::real(in.value); +} +template +auto imag(Simd in) -> Simd { + return std::imag(in.value); +} +template +Simd isnan(Simd in) { + return std::isnan(in.value); +} + +#define DEFAULT_BINARY(OP) \ + template \ + auto operator OP(Simd a, Simd b) \ + ->Simd { \ + return a.value OP b.value; \ + } \ + template \ + auto operator OP(T1 a, Simd b)->Simd { \ + return a OP b.value; \ + } \ + template \ + auto operator OP(Simd a, T2 b)->Simd { \ + return a.value OP b; \ + } + +DEFAULT_BINARY(+) +DEFAULT_BINARY(-) +DEFAULT_BINARY(*) +DEFAULT_BINARY(/) +DEFAULT_BINARY(<<) +DEFAULT_BINARY(>>) +DEFAULT_BINARY(|) +DEFAULT_BINARY(^) +DEFAULT_BINARY(&) +DEFAULT_BINARY(&&) +DEFAULT_BINARY(||) + +template +Simd clz(Simd x_) { + return __builtin_clz(x_.value); +} + +template +Simd remainder(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + T r; + if constexpr (std::is_integral_v) { + r = a % b; + } else { + r = std::remainder(a, b); + } + if constexpr (std::is_signed_v) { + if (r != 0 && (r < 0 != b < 0)) { + r += b; + } + } + return r; +} + +template +Simd maximum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr 
(!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a > b) ? a : b; +} + +template +Simd minimum(Simd a_, Simd b_) { + T a = a_.value; + T b = b_.value; + if constexpr (!std::is_integral_v) { + if (std::isnan(a)) { + return a; + } + } + return (a < b) ? a : b; +} + +template +Simd pow(Simd a, Simd b) { + T base = a.value; + T exp = b.value; + if constexpr (!std::is_integral_v) { + return std::pow(base, exp); + } else { + T res = 1; + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } +} + +template +Simd atan2(Simd a, Simd b) { + return std::atan2(a.value, b.value); +} + +#define DEFAULT_COMPARISONS(OP) \ + template \ + Simd operator OP(Simd a, Simd b) { \ + return a.value OP b.value; \ + } \ + template \ + Simd operator OP(T1 a, Simd b) { \ + return a OP b.value; \ + } \ + template \ + Simd operator OP(Simd a, T2 b) { \ + return a.value OP b; \ + } + +DEFAULT_COMPARISONS(>) +DEFAULT_COMPARISONS(<) +DEFAULT_COMPARISONS(>=) +DEFAULT_COMPARISONS(<=) +DEFAULT_COMPARISONS(==) +DEFAULT_COMPARISONS(!=) + +template +Simd select(Simd mask, Simd x, Simd y) { + return mask.value ? x.value : y.value; +} + +template +Simd clamp(Simd v, Simd min, Simd max) { + return std::clamp(v.value, min.value, max.value); +} + +template +Simd fma(Simd x, Simd y, U z) { + return std::fma(x.value, y.value, Simd(z).value); +} + +// Reductions +#define DEFAULT_REDUCTION(name, type) \ + template \ + type name(Simd x) { \ + return x.value; \ + } + +DEFAULT_REDUCTION(max, T) +DEFAULT_REDUCTION(min, T) +DEFAULT_REDUCTION(sum, T) +DEFAULT_REDUCTION(prod, T) +DEFAULT_REDUCTION(any, bool) +DEFAULT_REDUCTION(all, bool) + +} // namespace mlx::core::simd diff --git a/dist/include/mlx/backend/cpu/simd/math.h b/dist/include/mlx/backend/cpu/simd/math.h new file mode 100644 index 0000000..f9fc831 --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/math.h @@ -0,0 +1,193 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cpu/simd/type.h" + +namespace mlx::core::simd { + +constexpr float inf = std::numeric_limits::infinity(); + +/** + * Compute exp(x) in an optimizer friendly way as follows: + * + * First change the problem to computing 2**y where y = x / ln(2). + * + * Now we will compute 2**y as 2**y1 * 2**y2 where y1 is the integer part + * `ipart` and y2 is fractional part. For the integer part we perform bit + * shifting and for the fractional part we use a polynomial approximation. + * + * The algorithm and constants of the polynomial taken from + * https://github.com/akohlmey/fastermath/blob/master/src/exp.c which took them + * from Cephes math library. + * + * Note: The implementation below is a general fast exp. There could be faster + * implementations for numbers strictly < 0. + */ +template +Simd exp(Simd in) { + if constexpr (is_complex) { + return Simd{std::exp(in.value)}; + } else { + Simd x_init = in; + auto x = x_init * 1.442695f; // multiply with log_2(e) + Simd ipart, fpart; + ipart = floor(x + 0.5); + fpart = x - ipart; + + x = 1.535336188319500e-4f; + x = fma(x, fpart, 1.339887440266574e-3f); + x = fma(x, fpart, 9.618437357674640e-3f); + x = fma(x, fpart, 5.550332471162809e-2f); + x = fma(x, fpart, 2.402264791363012e-1f); + x = fma(x, fpart, 6.931472028550421e-1f); + x = fma(x, fpart, 1.000000000000000f); + + // generate 2**ipart in the floating point representation using integer + // bitshifting + Simd epart = (Simd(ipart) + 127) << 23; + + // Deal with NaN and Inf + auto result = select(isnan(x_init), x_init, (*(Simd*)&epart) * x); + result = select(x_init > 88.0f, Simd(inf), result); + result = select(x_init < -88.0f, Simd(0), result); + return Simd(result); + } +} + +/* Implementation from: + * https://github.com/JishinMaster/simd_utils/blob/3c1433a86fb38edcc9b02039f3c9a65b16640976/neon_mathfun.h#L357 + * which originally came from the Cephes math library. 
+ */ +template +Simd sincos(Simd in) { + auto sign_mask_sin = in < 0; + in = abs(in); + Simd x = in; + + // scale by 4/Pi + auto y = x * 1.27323954473516f; + + // store the integer part of y in mm0 + Simd emm2 = y; + + // j=(j+1) & (~1) (see the cephes sources) + emm2 = emm2 + 1; + emm2 = emm2 & ~1; + + y = emm2; + + // Get the polynom selection mask. There is one polynom for 0 <= x <= Pi/4 + // and another one for Pi/4(-0.78515625f), x); + x = fma(y, Simd(-2.4187564849853515625e-4f), x); + x = fma(y, Simd(-3.77489497744594108e-8f), x); + + sign_mask_sin = sign_mask_sin ^ ((emm2 & 4) != 0); + auto sign_mask_cos = ((emm2 - 2) & 4) != 0; + + // Evaluate the first polynom (0 <= x <= Pi/4) in y1, + // and the second polynom (Pi/4 <= x <= 0) in y2 + auto z = x * x; + + auto y1 = + fma(z, Simd(2.443315711809948e-5f), -1.388731625493765e-3f); + auto y2 = fma(z, Simd(-1.9515295891e-4f), 8.3321608736e-3f); + y1 = fma(y1, z, 4.166664568298827e-2f); + y2 = fma(y2, z, -1.6666654611e-1f); + y1 = y1 * z; + y2 = y2 * z; + y1 = y1 * z; + y2 = fma(x, y2, x); + y1 = fma(z, Simd(-0.5f), y1); + y1 = y1 + 1.0f; + + if constexpr (Sine) { + auto ys = select(poly_mask, y1, y2); + return select(sign_mask_sin, -ys, ys); + } else { + auto yc = select(poly_mask, y2, y1); + return select(sign_mask_cos, yc, -yc); + } +} + +template +Simd sin(Simd x) { + if constexpr (is_complex) { + return std::sin(x.value); + } else { + return sincos(x); + } +} + +template +Simd cos(Simd x) { + if constexpr (is_complex) { + return std::cos(x.value); + } else { + return sincos(x); + } +} + +template +Simd erf(Simd x) { + // https://github.com/pytorch/pytorch/blob/abf28982a8cb43342e7669d859de9543fd804cc9/aten/src/ATen/cpu/vec/vec256/vec256_float.h#L175 + Simd v = x; + auto t = recip(fma(Simd(0.3275911f), abs(v), 1.0f)); + auto r = fma(Simd(1.061405429f), t, -1.453152027f); + r = fma(r, t, 1.421413741f); + r = fma(r, t, -0.284496736f); + r = fma(r, t, 0.254829592f); + auto e = -exp(-v * v); + auto result = 
Simd(fma(e * t, r, 1.0f)); + return select(x > 0, result, -result); +} + +template +Simd erfinv(Simd a_) { + Simd a = a_; + auto t = fma(a, 0.0f - a, 1.0f); + t = log(t); + auto lhs = [](auto t) { + Simd p; + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + return fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + }; + auto rhs = [](auto t) { + Simd p; + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 + p = fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + return fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + }; + auto thresh = 6.125f; + // Compute both branches and select if N > 1 + if constexpr (N == 1) { + if ((abs(t) > thresh).value) { // maximum ulp error = 2.35793 + return a * lhs(t); + } else { // maximum ulp error = 2.35002 + return a * rhs(t); + } + } else { + return a * select(abs(t) > thresh, lhs(t), rhs(t)); + } +} + +} // namespace mlx::core::simd diff --git a/dist/include/mlx/backend/cpu/simd/neon_fp16_simd.h b/dist/include/mlx/backend/cpu/simd/neon_fp16_simd.h new file mode 100644 index 0000000..5d32042 --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/neon_fp16_simd.h @@ -0,0 +1,212 @@ +#pragma once + +#include + +#include "mlx/backend/cpu/simd/base_simd.h" + +namespace mlx::core::simd { + +constexpr int N = 8; + +template <> +struct Simd { + static constexpr int size = 
N; + using scalar_t = float16_t; + + Simd() {} + + template + Simd(U v) : value(vdupq_n_f16(v)){}; + + Simd(float16x8_t v) : value(v){}; + + Simd(Simd other) { + auto f32x4_a = *(float32x4_t*)(&other); + auto f32x4_b = *((float32x4_t*)(&other) + 1); + value = vcvt_high_f16_f32(vcvt_f16_f32(f32x4_a), f32x4_b); + }; + + Simd(Simd other) { + value = vcvtq_f16_u16(*(uint16x8_t*)(&other.value)); + }; + + operator Simd() { + auto v = vcvtq_s16_f16(value); + return load((int16_t*)&v); + }; + + operator Simd() { + float32x4x2_t v; + v.val[0] = vcvt_f32_f16(*(float16x4_t*)(&value)); + v.val[1] = vcvt_high_f32_f16(value); + return load((float*)&v); + } + float16_t operator[](int idx) const { + return reinterpret_cast(&value)[idx]; + } + + float16_t& operator[](int idx) { + return reinterpret_cast(&value)[idx]; + } + + float16x8_t value; +}; + +#define DEFINE_NEON_UNARY_OP(name, op) \ + inline Simd name(Simd a) { \ + return Simd{op(a.value)}; \ + } + +DEFINE_NEON_UNARY_OP(abs, vabsq_f16) +DEFINE_NEON_UNARY_OP(ceil, vrndpq_f16) +DEFINE_NEON_UNARY_OP(floor, vrndmq_f16) +DEFINE_NEON_UNARY_OP(sqrt, vsqrtq_f16) +DEFINE_NEON_UNARY_OP(rsqrt, vrsqrteq_f16) +DEFINE_NEON_UNARY_OP(recip, vrecpeq_f16) +DEFINE_NEON_UNARY_OP(rint, vrndnq_f16) + +#define DEFINE_NEON_BINARY_OP(name, op) \ + inline Simd name(Simd a, Simd b) { \ + return op(a.value, b.value); \ + } \ + template \ + Simd name(Simd a, T b) { \ + return op(a.value, Simd(b).value); \ + } \ + template \ + Simd name(T a, Simd b) { \ + return op(Simd(a).value, b.value); \ + } + +inline Simd operator!(Simd v) { + auto out = vceqzq_f16(v.value); + return Simd(*(uint16_t*)&out); +} + +inline Simd operator-(Simd v) { + return vnegq_f16(v.value); +} + +DEFINE_NEON_BINARY_OP(maximum, vmaxq_f16) +DEFINE_NEON_BINARY_OP(minimum, vminq_f16) +DEFINE_NEON_BINARY_OP(operator+, vaddq_f16) +DEFINE_NEON_BINARY_OP(operator-, vsubq_f16) +DEFINE_NEON_BINARY_OP(operator*, vmulq_f16) +DEFINE_NEON_BINARY_OP(operator/, vdivq_f16) + +#define 
DEFINE_NEON_COMPARISON(Op, op) \ + template \ + Simd operator Op(Simd a, T b) { \ + auto out = op(a.value, Simd(b).value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + template \ + Simd operator Op(T a, Simd b) { \ + auto out = op(Simd(a).value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } \ + inline Simd operator Op( \ + Simd a, Simd b) { \ + auto out = op(a.value, b.value); \ + return Simd(*(uint16_t*)(&out)); \ + } + +DEFINE_NEON_COMPARISON(==, vceqq_f16) +DEFINE_NEON_COMPARISON(>=, vcgeq_f16) +DEFINE_NEON_COMPARISON(<=, vcleq_f16) +DEFINE_NEON_COMPARISON(>, vcgtq_f16) +DEFINE_NEON_COMPARISON(<, vcltq_f16) + +template +Simd operator!=(Simd a, T b) { + return !(a == b); +} +template +Simd operator!=(T a, Simd b) { + return !(a == b); +} +inline Simd operator!=(Simd a, Simd b) { + return !(a == b); +} + +inline Simd operator||( + Simd a, + Simd b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(Simd a, T b) { + return Simd((a != 0) || (b != 0)); +} +template +Simd operator||(T a, Simd b) { + return Simd((a != 0) || (b != 0)); +} +inline Simd operator&&( + Simd a, + Simd b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(Simd a, T b) { + return Simd((a != 0) && (b != 0)); +} +template +Simd operator&&(T a, Simd b) { + return Simd((a != 0) && (b != 0)); +} + +template <> +inline Simd isnan(Simd v) { + return v != v; +} + +template <> +inline Simd +clamp(Simd v, Simd min, Simd max) { + return minimum(maximum(v, min), max); +} + +template +Simd fma(Simd x, Simd y, T z) { + return vfmaq_f16(x.value, y.value, Simd(z).value); +} + +template +Simd +select(Simd mask, Simd x, Simd y) { + return vbslq_f16(Simd(mask).value, x.value, y.value); +} + +// Reductions +inline float16_t max(Simd x) { + float16x4_t y; + y = vpmax_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmax_f16(y, y); + y = vpmax_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t min(Simd x) { + float16x4_t y; + y = 
vpmin_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpmin_f16(y, y); + y = vpmin_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t sum(Simd x) { + float16x4_t y; + y = vpadd_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + y = vpadd_f16(y, y); + y = vpadd_f16(y, y); + return vget_lane_f16(y, 0); +} +inline float16_t prod(Simd x) { + auto hx = vmul_f16(vget_low_f16(x.value), vget_high_f16(x.value)); + auto out = hx[0]; + hx[0] *= hx[1]; + hx[0] *= hx[2]; + hx[0] *= hx[3]; + return hx[0]; +} + +} // namespace mlx::core::simd diff --git a/dist/include/mlx/backend/cpu/simd/simd.h b/dist/include/mlx/backend/cpu/simd/simd.h new file mode 100644 index 0000000..8700f24 --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/simd.h @@ -0,0 +1,4 @@ +#pragma once + +#include "mlx/backend/cpu/simd/math.h" +#include "mlx/backend/cpu/simd/type.h" diff --git a/dist/include/mlx/backend/cpu/simd/type.h b/dist/include/mlx/backend/cpu/simd/type.h new file mode 100644 index 0000000..59b6ecc --- /dev/null +++ b/dist/include/mlx/backend/cpu/simd/type.h @@ -0,0 +1,11 @@ +#pragma once + +#include "mlx/backend/cpu/simd/base_simd.h" + +#ifdef MLX_USE_ACCELERATE +#if defined(__x86_64__) +// the accelerate_simd implementation require neon -- use base implementation +#else +#include "mlx/backend/cpu/simd/accelerate_simd.h" +#endif +#endif diff --git a/dist/include/mlx/backend/cpu/slicing.h b/dist/include/mlx/backend/cpu/slicing.h new file mode 100644 index 0000000..eda3732 --- /dev/null +++ b/dist/include/mlx/backend/cpu/slicing.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +std::tuple prepare_slice( + const array& in, + const Shape& start_indices, + const Shape& strides); + +void shared_buffer_slice( + const array& in, + const Strides& out_strides, + size_t data_offset, + size_t data_size, + array& out); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/ternary.h b/dist/include/mlx/backend/cpu/ternary.h new file mode 100644 index 0000000..a27a7f2 --- /dev/null +++ b/dist/include/mlx/backend/cpu/ternary.h @@ -0,0 +1,154 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include "mlx/array.h" +#include "mlx/backend/common/ternary.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cpu/encoder.h" + +namespace mlx::core { + +template +void ternary_op_dims( + const T1* a, + const T2* b, + const T3* c, + U* out, + Op op, + const Shape& shape, + const Strides& a_strides, + const Strides& b_strides, + const Strides& c_strides, + const Strides& out_strides, + int axis) { + auto stride_a = a_strides[axis]; + auto stride_b = b_strides[axis]; + auto stride_c = c_strides[axis]; + auto stride_out = out_strides[axis]; + auto N = shape[axis]; + + for (int i = 0; i < N; i++) { + if constexpr (D > 1) { + ternary_op_dims( + a, + b, + c, + out, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + axis + 1); + } else { + *out = op(*a, *b, *c); + } + a += stride_a; + b += stride_b; + c += stride_c; + out += stride_out; + } +} + +template +void ternary_op_dispatch_dims( + const T1* a_ptr, + const T2* b_ptr, + const T3* c_ptr, + U* out_ptr, + Op op, + size_t size, + Shape& shape, + std::vector& strides) { + const auto& a_strides = strides[0]; + const auto& b_strides = strides[1]; + const auto& c_strides = strides[2]; + const auto& out_strides = strides[3]; + int ndim = shape.size(); + switch (ndim) { + case 1: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 
0); + return; + case 2: + ternary_op_dims( + a_ptr, + b_ptr, + c_ptr, + out_ptr, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + 0); + return; + } + + ContiguousIterator a_it(shape, a_strides, ndim - 2); + ContiguousIterator b_it(shape, b_strides, ndim - 2); + ContiguousIterator c_it(shape, c_strides, ndim - 2); + auto stride = out_strides[ndim - 3]; + for (size_t elem = 0; elem < size; elem += stride) { + ternary_op_dims( + a_ptr + a_it.loc, + b_ptr + b_it.loc, + c_ptr + c_it.loc, + out_ptr + elem, + op, + shape, + a_strides, + b_strides, + c_strides, + out_strides, + ndim - 2); + a_it.step(); + b_it.step(); + c_it.step(); + } +} + +template +void ternary_op( + const array& a, + const array& b, + const array& c, + array& out, + Op op, + TernaryOpType topt) { + const T1* a_ptr = a.data(); + const T2* b_ptr = b.data(); + const T3* c_ptr = c.data(); + U* out_ptr = out.data(); + + if (topt == TernaryOpType::ScalarScalarScalar) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + } else if (topt == TernaryOpType::VectorVectorVector) { + for (size_t i = 0; i < out.size(); ++i) { + *out_ptr = op(*a_ptr, *b_ptr, *c_ptr); + a_ptr++; + b_ptr++; + c_ptr++; + out_ptr++; + } + } else { + auto [shape, strides] = collapse_contiguous_dims( + a.shape(), {a.strides(), b.strides(), c.strides(), out.strides()}); + ternary_op_dispatch_dims( + a_ptr, b_ptr, c_ptr, out_ptr, op, out.size(), shape, strides); + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/threefry.h b/dist/include/mlx/backend/cpu/threefry.h new file mode 100644 index 0000000..0fc485f --- /dev/null +++ b/dist/include/mlx/backend/cpu/threefry.h @@ -0,0 +1,21 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +namespace mlx::core::random { + +/** Applies the Threefry 2x32 hash function. 
+ * This code is based on the Jax counter-based and splittable PRNG + * https://github.com/google/jax/blob/main/docs/jep/263-prng.md + * + * Original Threefry reference: + * http://www.thesalmons.org/john/random123/papers/random123sc11.pdf + */ +std::pair threefry2x32_hash( + const std::pair& key, + std::pair count); + +} // namespace mlx::core::random diff --git a/dist/include/mlx/backend/cpu/unary.h b/dist/include/mlx/backend/cpu/unary.h new file mode 100644 index 0000000..4fab6a7 --- /dev/null +++ b/dist/include/mlx/backend/cpu/unary.h @@ -0,0 +1,281 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/common/unary.h" +#include "mlx/backend/cpu/encoder.h" +#include "mlx/backend/cpu/simd/simd.h" +#include "mlx/utils.h" + +namespace mlx::core { + +template +void unary_op(const T* a, U* out, size_t shape, size_t stride) { + for (size_t i = 0; i < shape; i += 1) { + out[i] = Op{}(*a); + a += stride; + } +} + +template +void unary_op(const array& a, array& out, Op) { + const T* src = a.data(); + U* dst = out.data(); + auto ndim = a.ndim(); + if (a.flags().contiguous) { + auto size = a.data_size(); + constexpr int N = std::min(simd::max_size, simd::max_size); + while (size >= N) { + simd::store(dst, simd::Simd(Op{}(simd::load(src)))); + size -= N; + src += N; + dst += N; + } + while (size > 0) { + *dst = Op{}(*src); + size--; + dst++; + src++; + } + } else { + size_t shape = ndim > 0 ? a.shape().back() : 1; + size_t stride = ndim > 0 ? 
a.strides().back() : 1; + if (ndim <= 1) { + unary_op(src, dst, shape, stride); + return; + } + auto it = ContiguousIterator(a.shape(), a.strides(), ndim - 1); + for (size_t elem = 0; elem < a.size(); elem += shape) { + unary_op(src + it.loc, dst + elem, shape, stride); + it.step(); + } + } +} + +template +void unary(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bool_: + unary_op(a, out, op); + break; + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + } + }); +} + +template +void unary_real_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_real] Does not support " << 
out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} +template +void unary_fp(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case bfloat16: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_fp] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} + +template +void unary_signed(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + case float16: + unary_op(a, out, op); + break; + case float32: + unary_op(a, out, op); + break; + case float64: + unary_op(a, out, op); + break; + case bfloat16: + unary_op(a, out, op); + break; + case complex64: + unary_op(a, out, op); + break; + default: + throw std::runtime_error("[Abs] Called on unsigned type"); + } + }); +} + +template +void unary_complex(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out 
= array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_complex_to_float(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch( + [a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { unary_op(a, out, op); }); +} + +template +void unary_int(const array& a, array& out, Op op, Stream stream) { + set_unary_output_data(a, out); + auto& encoder = cpu::get_command_encoder(stream); + encoder.set_input_array(a); + encoder.set_output_array(out); + encoder.dispatch([a = array::unsafe_weak_copy(a), + out = array::unsafe_weak_copy(out), + op = op]() mutable { + switch (out.dtype()) { + case uint8: + unary_op(a, out, op); + break; + case uint16: + unary_op(a, out, op); + break; + case uint32: + unary_op(a, out, op); + break; + case uint64: + unary_op(a, out, op); + break; + case int8: + unary_op(a, out, op); + break; + case int16: + unary_op(a, out, op); + break; + case int32: + unary_op(a, out, op); + break; + case int64: + unary_op(a, out, op); + break; + default: + std::ostringstream err; + err << "[unary_int] Does not support " << out.dtype(); + throw std::runtime_error(err.str()); + } + }); +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cpu/unary_ops.h b/dist/include/mlx/backend/cpu/unary_ops.h new file mode 100644 index 0000000..b68091c --- /dev/null +++ b/dist/include/mlx/backend/cpu/unary_ops.h @@ -0,0 +1,180 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include +#include +#include + +#include "mlx/backend/cpu/simd/simd.h" + +namespace mlx::core::detail { + +using namespace mlx::core::simd; + +#define SINGLE() \ + template \ + T operator()(T x) { \ + return (*this)(Simd(x)).value; \ + } + +#define DEFAULT_OP(Op, op) \ + struct Op { \ + template \ + Simd operator()(Simd x) { \ + return simd::op(x); \ + } \ + SINGLE() \ + }; + +DEFAULT_OP(Abs, abs) +DEFAULT_OP(ArcCos, acos) +DEFAULT_OP(ArcCosh, acosh) +DEFAULT_OP(ArcSin, asin) +DEFAULT_OP(ArcSinh, asinh) +DEFAULT_OP(ArcTan, atan) +DEFAULT_OP(ArcTanh, atanh) +DEFAULT_OP(BitwiseInvert, operator~) +DEFAULT_OP(Ceil, ceil) +DEFAULT_OP(Conjugate, conj) +DEFAULT_OP(Cos, cos) +DEFAULT_OP(Cosh, cosh) +DEFAULT_OP(Erf, erf) +DEFAULT_OP(ErfInv, erfinv) +DEFAULT_OP(Exp, exp) +DEFAULT_OP(Expm1, expm1) +DEFAULT_OP(Floor, floor); +DEFAULT_OP(Log, log); +DEFAULT_OP(Log2, log2); +DEFAULT_OP(Log10, log10); +DEFAULT_OP(Log1p, log1p); +DEFAULT_OP(LogicalNot, operator!) +DEFAULT_OP(Negative, operator-) +DEFAULT_OP(Round, rint); +DEFAULT_OP(Sin, sin) +DEFAULT_OP(Sinh, sinh) +DEFAULT_OP(Sqrt, sqrt) +DEFAULT_OP(Rsqrt, rsqrt) +DEFAULT_OP(Tan, tan) +DEFAULT_OP(Tanh, tanh) + +struct Imag { + template + Simd operator()(Simd x) { + return simd::imag(x); + } + SINGLE() +}; + +struct Real { + template + Simd operator()(Simd x) { + return simd::real(x); + } + SINGLE() +}; + +struct Sigmoid { + template + Simd operator()(Simd x) { + auto y = 1.0f / (1.0f + simd::exp(simd::abs(x))); + return simd::select(x < Simd{0}, y, Simd{1} - y); + } + SINGLE() +}; + +struct Sign { + template + Simd operator()(Simd x) { + auto z = Simd{0}; + auto o = Simd{1}; + auto m = Simd{-1}; + if constexpr (std::is_unsigned_v) { + return simd::select(x == z, z, o); + } else if constexpr (std::is_same_v) { + return simd::select(x == z, x, Simd(x / simd::abs(x))); + } else { + return simd::select(x < z, m, simd::select(x > z, o, z)); + } + } + SINGLE() +}; + +struct Square { + template + Simd operator()(Simd 
x) { + return x * x; + } + SINGLE() +}; + +template +Simd fp32_from_bits(Simd x) { + return *(Simd*)(&x); +} +template +Simd fp32_to_bits(Simd x) { + return *(Simd*)(&x); +} + +struct ToFP8 { + template + Simd operator()(Simd f) { + uint32_t fp8_max = 543 << 21; + auto denorm_mask = Simd(141 << 23); + Simd f_bits; + Simd f32 = f; + f_bits = fp32_to_bits(f32); + Simd result = 0u; + auto sign = f_bits & 0x80000000; + f_bits = f_bits ^ sign; + + auto f_bits_low = + fp32_to_bits(fp32_from_bits(f_bits) + fp32_from_bits(denorm_mask)); + auto result_low = Simd(f_bits_low - denorm_mask); + + auto mant_odd = Simd((f_bits >> 20) & 1); + auto f_bits_high = f_bits + (((uint32_t)(7 - 127) << 23) + 0x7FFFF); + f_bits_high = f_bits_high + Simd(mant_odd); + + auto result_high = Simd(f_bits_high >> 20); + result = select(f_bits < (121 << 23), result_low, result_high); + + auto result_sat = Simd(0x7E); + result = select(f_bits >= fp8_max, result_sat, result); + return result | Simd(sign >> 24); + } + + template + uint8_t operator()(T x) { + return (*this)(Simd(x)).value; + } +}; + +struct FromFP8 { + template + Simd operator()(Simd x) { + auto w = Simd(x) << 24; + auto sign = w & 0x80000000; + auto nonsign = w & 0x7FFFFFFF; + + auto renorm_shift = clz(nonsign); + renorm_shift = simd::select( + renorm_shift > Simd{4}, + renorm_shift - Simd{4}, + Simd{0}); + + Simd inf_nan_mask = + (Simd(nonsign + 0x01000000) >> 8) & 0x7F800000; + auto zero_mask = Simd(nonsign - 1) >> 31; + auto result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return fp32_from_bits(result); + } + float operator()(uint8_t x) { + return (*this)(Simd(x)).value; + } +}; +} // namespace mlx::core::detail diff --git a/dist/include/mlx/backend/cuda/allocator.h b/dist/include/mlx/backend/cuda/allocator.h new file mode 100644 index 0000000..7f6ad52 --- /dev/null +++ b/dist/include/mlx/backend/cuda/allocator.h @@ -0,0 +1,89 @@ +// Copyright © 2025 Apple 
Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" +#include "mlx/backend/cuda/cuda_utils.h" + +#include +#include +#include +#include + +namespace mlx::core::cu { + +class CommandEncoder; + +using allocator::Buffer; + +// Stores cuda-managed unified memory. +struct CudaBuffer { + void* data; + size_t size; + int device; // -1 for managed +}; + +class SmallSizePool { + private: + union Block { + Block* next; + CudaBuffer buf; + }; + + Block* buffer_{nullptr}; + void* data_{nullptr}; + Block* next_free_{nullptr}; + + public: + SmallSizePool(); + ~SmallSizePool(); + + SmallSizePool(const SmallSizePool&) = delete; + SmallSizePool& operator=(const SmallSizePool&) = delete; + + CudaBuffer* malloc(); + void free(CudaBuffer* buf); + bool in_pool(CudaBuffer* buf); +}; + +class CudaAllocator : public allocator::Allocator { + public: + Buffer malloc(size_t size) override; + Buffer malloc_async(size_t size, int device, cudaStream_t stream); + void free(Buffer buffer) override; + size_t size(Buffer buffer) const override; + + size_t get_active_memory() const; + size_t get_peak_memory() const; + void reset_peak_memory(); + size_t get_memory_limit(); + size_t set_memory_limit(size_t limit); + size_t get_cache_memory() const; + size_t set_cache_limit(size_t limit); + void clear_cache(); + + private: + void cuda_free(CudaBuffer* buf); + + CudaAllocator(); + friend CudaAllocator& allocator(); + + std::mutex mutex_; + size_t memory_limit_; + size_t free_limit_; + size_t total_memory_; + size_t max_pool_size_; + BufferCache buffer_cache_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + std::vector free_streams_; + std::vector mem_pools_; + SmallSizePool scalar_pool_; +}; + +CudaAllocator& allocator(); + +Buffer malloc_async(size_t size, CommandEncoder& encoder); + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/conv/conv.h b/dist/include/mlx/backend/cuda/conv/conv.h new file mode 100644 index 
0000000..62dc934 --- /dev/null +++ b/dist/include/mlx/backend/cuda/conv/conv.h @@ -0,0 +1,126 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/gpu/copy.h" + +namespace mlx::core { + +template +struct ConvParams { + int N; // Batch size + int C; // In channels + int O; // Out channels + int strides[NDIM]; + int padding[NDIM]; + int kernel_dilation[NDIM]; + int input_dilation[NDIM]; + int groups; + bool flip; + int in_spatial_dims[NDIM]; + int wt_spatial_dims[NDIM]; + int out_spatial_dims[NDIM]; + int64_t in_strides[NDIM + 2]; + + ConvParams( + const array& in, + const array& wt, + const array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip) + : N(in.shape(0)), + C(in.shape(-1)), + O(wt.shape(0)), + groups(groups), + flip(flip) { + std::copy_n(strides.begin(), NDIM, this->strides); + std::copy_n(padding.begin(), NDIM, this->padding); + std::copy_n(kernel_dilation.begin(), NDIM, this->kernel_dilation); + std::copy_n(input_dilation.begin(), NDIM, this->input_dilation); + std::copy_n(in.shape().begin() + 1, NDIM, this->in_spatial_dims); + std::copy_n(wt.shape().begin() + 1, NDIM, this->wt_spatial_dims); + std::copy_n(out.shape().begin() + 1, NDIM, this->out_spatial_dims); + std::copy_n(in.strides().begin(), NDIM + 2, this->in_strides); + } +}; + +void gemm_grouped_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s); + +void gemm_conv( + cu::CommandEncoder& encoder, + const array& in, + const array& wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + bool flip, + Stream s); + 
+inline void gemm_conv( + cu::CommandEncoder& encoder, + array in, + array wt, + array& out, + const std::vector& strides, + const std::vector& padding, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + int groups, + bool flip, + Stream s) { + if (!in.flags().row_contiguous) { + in = contiguous_copy_gpu(in, s); + encoder.add_temporary(in); + } + if (!wt.flags().row_contiguous) { + wt = contiguous_copy_gpu(wt, s); + encoder.add_temporary(wt); + } + + if (groups == 1) { + gemm_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + flip, + s); + } else { + gemm_grouped_conv( + encoder, + in, + wt, + out, + strides, + padding, + kernel_dilation, + input_dilation, + groups, + flip, + s); + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/cublas_utils.h b/dist/include/mlx/backend/cuda/cublas_utils.h new file mode 100644 index 0000000..11f3228 --- /dev/null +++ b/dist/include/mlx/backend/cuda/cublas_utils.h @@ -0,0 +1,96 @@ +// Copyright © 2025 Apple Inc. 
+#pragma once + +#include +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/dtype_utils.h" + +namespace mlx::core { +namespace cublas_utils { + +// Get the shared cublas preference for a device +cublasLtMatmulPreference_t get_preference(cu::Device& device); + +void* allocate_workspace(cu::CommandEncoder& encoder, size_t workspace_size); + +cublasLtMatrixLayout_t create_matrix_layout( + cudaDataType_t type, + uint64_t rows, + uint64_t cols, + bool transposed, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + +inline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) { + switch (dtype) { + case float16: + return CUDA_R_16F; + case bfloat16: + return CUDA_R_16BF; + case float32: + return CUDA_R_32F; + case float64: + return CUDA_R_64F; + case complex64: + return CUDA_C_32F; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in {}: {}.", tag, dtype_to_string(dtype))); + } +} + +} // namespace cublas_utils + +class CublasMatmulBase { + public: + virtual ~CublasMatmulBase(); + + void set_bias(cu::CommandEncoder& encoder, const array& bias); + + protected: + CublasMatmulBase() = default; + + // Common member variables shared by all matmul types + uint64_t M_; + uint64_t N_; + cudaDataType_t scale_type_; + cublasLtMatmulPreference_t pref_{nullptr}; + cublasLtHandle_t handle_{nullptr}; + cublasLtMatmulDesc_t matmul_desc_{nullptr}; + cublasLtMatrixLayout_t a_desc_{nullptr}; + cublasLtMatrixLayout_t b_desc_{nullptr}; + cublasLtMatrixLayout_t c_desc_{nullptr}; + cublasLtMatrixLayout_t out_desc_{nullptr}; + cublasLtMatmulHeuristicResult_t heuristic_; + + void init_base( + cu::Device& device, + cudaDataType_t scale_type, + cublasComputeType_t compute_type, + cudaDataType_t data_type, + cudaDataType_t output_type, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t 
a_batch_stride, + int64_t b_batch_stride); + + void execute_matmul( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + const void* alpha_ptr, + const void* beta_ptr); +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/cuda.h b/dist/include/mlx/backend/cuda/cuda.h new file mode 100644 index 0000000..2c6a5c7 --- /dev/null +++ b/dist/include/mlx/backend/cuda/cuda.h @@ -0,0 +1,10 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +namespace mlx::core::cu { + +/* Check if the CUDA backend is available. */ +bool is_available(); + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/cuda_utils.h b/dist/include/mlx/backend/cuda/cuda_utils.h new file mode 100644 index 0000000..d9d5576 --- /dev/null +++ b/dist/include/mlx/backend/cuda/cuda_utils.h @@ -0,0 +1,89 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +// Throw exception if the cuda API does not succeed. +void check_cublas_error(const char* name, cublasStatus_t err); +void check_cuda_error(const char* name, cudaError_t err); +void check_cuda_error(const char* name, CUresult err); +void check_cudnn_error(const char* name, cudnnStatus_t err); + +// The macro version that prints the command that failed. +#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd)) +#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd)) +#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd)) + +// Base class for RAII managed CUDA resources. 
+template +class CudaHandle { + public: + CudaHandle(Handle handle = nullptr) : handle_(handle) {} + + CudaHandle(CudaHandle&& other) : handle_(other.handle_) { + assert(this != &other); + other.handle_ = nullptr; + } + + ~CudaHandle() { + // Skip if there was an error to avoid throwing in the destructors + if (cudaPeekAtLastError() != cudaSuccess) { + return; + } + reset(); + } + + CudaHandle(const CudaHandle&) = delete; + CudaHandle& operator=(const CudaHandle&) = delete; + + CudaHandle& operator=(CudaHandle&& other) { + assert(this != &other); + reset(); + std::swap(handle_, other.handle_); + return *this; + } + + void reset() { + if (handle_ != nullptr) { + CHECK_CUDA_ERROR(Destroy(handle_)); + handle_ = nullptr; + } + } + + operator Handle() const { + return handle_; + } + + protected: + Handle handle_; +}; + +namespace cu { +class Device; +}; // namespace cu + +// Wrappers of CUDA resources. +class CudaGraph : public CudaHandle { + public: + using CudaHandle::CudaHandle; + explicit CudaGraph(cu::Device& device); + void end_capture(cudaStream_t stream); +}; + +class CudaGraphExec : public CudaHandle { + public: + void instantiate(cudaGraph_t graph); +}; + +class CudaStream : public CudaHandle { + public: + explicit CudaStream(cu::Device& device); +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/cudnn_utils.h b/dist/include/mlx/backend/cuda/cudnn_utils.h new file mode 100644 index 0000000..5a3d2f9 --- /dev/null +++ b/dist/include/mlx/backend/cuda/cudnn_utils.h @@ -0,0 +1,171 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cuda/device/config.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/dtype_utils.h" + +#include +#include + +namespace mlx::core { + +namespace cu { +class CommandEncoder; +} + +namespace fe = cudnn_frontend; + +#define CHECK_CUDNN_FE_ERROR(cmd) \ + do { \ + auto error = cmd; \ + if (!error.is_good()) { \ + throw std::runtime_error( \ + fmt::format("{} failed: {}.", #cmd, error.get_message())); \ + } \ + } while (0) + +// Return pointer alignment of |x|'s data. +inline uint8_t get_alignment(const array& x) { + uint8_t alignment = 1; + uintptr_t address = reinterpret_cast(gpu_ptr(x)); + for (; alignment < 32; alignment *= 2) { + if (address % (alignment * 2)) { + return alignment; + } + } + return alignment; +} + +// Convert the type of elements in |vec| to |T|. +template +inline std::vector convert_vector(const Vec& vec) { + return std::vector(vec.begin(), vec.end()); +} + +// Map dtype to cudnn data type. +inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) { + switch (dtype) { + case int8: + return fe::DataType_t::INT8; + case int32: + return fe::DataType_t::INT32; + case uint8: + return fe::DataType_t::UINT8; + case float16: + return fe::DataType_t::HALF; + case bfloat16: + return fe::DataType_t::BFLOAT16; + case float32: + return fe::DataType_t::FLOAT; + case float64: + return fe::DataType_t::DOUBLE; + default: + throw std::runtime_error(fmt::format( + "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype))); + } +} + +// Return an array that can be used as map key for |vec| with size <= MAX_NDIM. +// +// There are 2 differences from the const_param util from kernel_utils.cuh: +// 1. The rest of array is filled with 0. +// 2. This util can be used in .cpp files. 
+template class Vec> +inline std::array vector_key(const Vec& vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + std::array result = {}; + std::copy_n(vec.begin(), vec.size(), result.begin()); + return result; +} + +// Extends cuDNN graph with helpers. +class DnnGraph : public fe::graph::Graph { + public: + DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32) + : handle_(handle) { + set_io_data_type(dtype_to_cudnn_type(io_dtype)); + set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype)); + set_compute_data_type(dtype_to_cudnn_type(compute_dtype)); + } + + // Create a cuDNN tensor description from MLX array |x|. + auto& tensor( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs(attrs, uid, x); + return attrs; + } + auto tensor(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor(attrs, uid, x); + return attrs; + } + + // Create a cuDNN tensor description from MLX array |x|, and transpose it from + // NHWC layout to NCHW. + auto& tensor_nchw( + std::shared_ptr& attrs, + int64_t uid, + const array& x) { + set_tensor_attrs_nchw(attrs, uid, x); + return attrs; + } + auto tensor_nchw(const char* name, int64_t uid, const array& x) { + auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name)); + tensor_nchw(attrs, uid, x); + return attrs; + } + + // Create a cuDNN tensor for scalar. + auto scalar(const char* name, int64_t uid, Dtype dtype) { + return Graph::tensor(fe::graph::Tensor_attributes() + .set_name(name) + .set_uid(uid) + .set_dim({1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_is_pass_by_value(true) + .set_data_type(dtype_to_cudnn_type(dtype))); + } + + // Call this before setting notes. + fe::error_t prepare(); + // Call this after setting notes. 
+ fe::error_t build(); + + // Add cuDNN graph to CUDA graph, using native CUDA graph API. + fe::error_t encode_graph( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); + // Add cuDNN graph to CUDA graph, using stream capture. + fe::error_t encode_capturing( + cu::CommandEncoder& encoder, + std::unordered_map variant_pack); + + private: + void* prepare_workspace(cu::CommandEncoder& encoder); + + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x, + const std::vector& shape, + const std::vector& strides); + void set_tensor_attrs( + std::shared_ptr& tensor, + int64_t uid, + const array& x); + void set_tensor_attrs_nchw( + std::shared_ptr& tensor, + int64_t uid, + const array& x); + + cudnnHandle_t handle_; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/device.h b/dist/include/mlx/backend/cuda/device.h new file mode 100644 index 0000000..7d317d3 --- /dev/null +++ b/dist/include/mlx/backend/cuda/device.h @@ -0,0 +1,189 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/lru_cache.h" +#include "mlx/backend/cuda/worker.h" +#include "mlx/stream.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core::cu { + +class CommandEncoder { + public: + struct CaptureContext { + CaptureContext(CommandEncoder& enc); + ~CaptureContext(); + CudaGraph graph; + CommandEncoder& enc; + bool discard{false}; + }; + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc); + ~ConcurrentContext(); + CommandEncoder& enc; + }; + + explicit CommandEncoder(Device& d); + + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + CaptureContext capture_context() { + return CaptureContext{*this}; + } + ConcurrentContext concurrent_context() { + return ConcurrentContext{*this}; + } + + void set_input_array(const array& arr); + void set_output_array(const array& arr); + + template + void add_kernel_node( + F* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + Params&&... 
params) { + constexpr size_t num = sizeof...(Params); + void* ptrs[num]; + size_t i = 0; + ([&](auto&& p) { ptrs[i++] = static_cast(&p); }( + std::forward(params)), + ...); + add_kernel_node((void*)func, grid_dim, block_dim, smem_bytes, ptrs); + } + + void add_kernel_node( + CUfunction func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params); + + void add_kernel_node( + void* func, + dim3 grid_dim, + dim3 block_dim, + uint32_t smem_bytes, + void** params); + + void add_graph_node(cudaGraph_t child); + + void add_temporary(const array& arr) { + temporaries_.push_back(arr.data_shared_ptr()); + } + + void add_completed_handler(std::function task); + bool needs_commit(); + void commit(); + + Device& device() { + return device_; + } + + CudaStream& stream() { + return stream_; + } + + // Wait until kernels and completion handlers are finished + void synchronize(); + + private: + void add_kernel_node(const cudaKernelNodeParams& params); + void add_kernel_node(const CUDA_KERNEL_NODE_PARAMS& params); + + struct GraphNode { + cudaGraphNode_t node; + // K = kernel + // E = empty + // () = subgraph (with metadata) + // Symbols ':', '-' are reserved as separators + std::string node_type; + std::string id; + }; + + void insert_graph_dependencies(GraphNode node); + void insert_graph_dependencies(std::vector nodes); + + Device& device_; + CudaStream stream_; + CudaGraph graph_; + Worker worker_; + char node_count_{0}; + bool in_concurrent_{false}; + std::vector from_nodes_; + std::vector to_nodes_; + std::string graph_nodes_key_; + std::string graph_deps_key_; + std::vector concurrent_nodes_; + std::vector> temporaries_; + LRUCache graph_cache_; + std::vector active_deps_; + std::vector active_outputs_; + std::unordered_map node_map_; + size_t bytes_in_graph_{0}; + bool is_graph_updatable_{true}; + int max_ops_per_graph_; + int max_mb_per_graph_; +}; + +class Device { + public: + explicit Device(int device); + ~Device(); + + Device(const Device&) = delete; 
+ Device& operator=(const Device&) = delete; + + // Make this device the current cuda device, this method is thread-safe. + void make_current(); + + CommandEncoder& get_command_encoder(Stream s); + + int cuda_device() const { + return device_; + } + int compute_capability_major() const { + return compute_capability_major_; + } + int compute_capability_minor() const { + return compute_capability_minor_; + } + cublasLtHandle_t lt_handle() const { + return lt_; + } + cudnnHandle_t cudnn_handle() const { + return cudnn_; + } + + private: + int device_; + int compute_capability_major_; + int compute_capability_minor_; + std::string device_name_; + cublasLtHandle_t lt_; + cudnnHandle_t cudnn_; + std::unordered_map encoders_; +}; + +Device& device(mlx::core::Device device); +CommandEncoder& get_command_encoder(Stream s); + +// Return an execution policy that does not sync for result. +// Note that not all thrust APIs support async policy, confirm before using. +inline auto thrust_policy(cudaStream_t stream) { + // TODO: Connect thrust's custom allocator with mlx's allocator. + return thrust::cuda::par_nosync.on(stream); +} + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/device/config.h b/dist/include/mlx/backend/cuda/device/config.h new file mode 100644 index 0000000..5a34029 --- /dev/null +++ b/dist/include/mlx/backend/cuda/device/config.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +// This file is used by both CUDA kernel code and host-only C++ code. + +#pragma once + +// The maximum dimensions of shape/strides passed as kernel parameters. +#define MAX_NDIM 10 + +// All existing NVIDIA hardware has a fixed 32 warp size. Though a built-in +// warpSize variable exists, using it would prevent compile-time optimizations. 
+#define WARP_SIZE 32 diff --git a/dist/include/mlx/backend/cuda/event.h b/dist/include/mlx/backend/cuda/event.h new file mode 100644 index 0000000..342e6ae --- /dev/null +++ b/dist/include/mlx/backend/cuda/event.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/allocator.h" +#include "mlx/backend/cuda/utils.h" +#include "mlx/stream.h" + +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +// RAII-managed move-only wrapper of cudaEvent_t. +struct CudaEventHandle : public CudaHandle { + CudaEventHandle(Device& d, int flags); + Device& device; + int flags; +}; + +// Wrapper of native cuda event. It can synchronize between GPU streams, or wait +// on GPU stream in CPU stream, but can not wait on CPU stream. +class CudaEvent { + public: + CudaEvent(Device& d, int flags); + ~CudaEvent(); + + CudaEvent(CudaEvent&&) = default; + CudaEvent& operator=(CudaEvent&&) = default; + + CudaEvent(const CudaEvent&) = delete; + CudaEvent& operator=(const CudaEvent&) = delete; + + void wait(); + void wait(cudaStream_t stream); + void record(cudaStream_t stream); + + // Return whether the recorded kernels have completed. Note that this method + // returns true if record() has not been called. + bool completed() const; + + // Internal: make sure event pool is initialized. + static void init_pool(); + + private: + CudaEventHandle event_; +}; + +// Event that can synchronize between CPU and GPU. It is much slower than +// CudaEvent so the latter should always be preferred when possible. 
+class AtomicEvent { + public: + using Atomic = cuda::atomic; + + AtomicEvent(); + + void wait(uint64_t value); + void wait(cudaStream_t stream, uint64_t value); + void wait(Stream s, uint64_t value); + void signal(uint64_t value); + void signal(cudaStream_t stream, uint64_t value); + void signal(Stream s, uint64_t value); + bool is_signaled(uint64_t value) const; + uint64_t value() const; + + private: + Atomic* atomic() const { + return static_cast(buf_->raw_ptr()); + } + + std::shared_ptr buf_; +}; + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/gemms/cublas_gemm.h b/dist/include/mlx/backend/cuda/gemms/cublas_gemm.h new file mode 100644 index 0000000..1fad45e --- /dev/null +++ b/dist/include/mlx/backend/cuda/gemms/cublas_gemm.h @@ -0,0 +1,114 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasGemm : public CublasMatmulBase { + public: + CublasGemm( + cu::Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride); + + CublasGemm( + cu::Device& device, + Dtype dtype, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride); + + // The output's descriptor is inferred from inputs by default, use this method + // for unusual output. 
+ void set_out( + Dtype dtype, + bool transposed, + uint64_t rows, + uint64_t cols, + int64_t ld, + int32_t batch_count, + int64_t batch_stride); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha = 1.0f); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta); + + private: + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha); + + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& c, + const Shape& batch_shape, + const Strides& a_batch_strides, + const Strides& b_batch_strides, + const Strides& c_batch_strides, + float alpha, + float beta); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* c, + float alpha = 1, + float beta = 0); +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/gemms/gemv.h b/dist/include/mlx/backend/cuda/gemms/gemv.h new file mode 100644 index 0000000..27173aa --- /dev/null +++ b/dist/include/mlx/backend/cuda/gemms/gemv.h @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/cuda/device.h" + +namespace mlx::core::cu { + +bool can_use_gemv(int M, int N, int K, bool a_transposed, bool b_transposed); + +void gemv( + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + uint32_t batch_count, + const mlx::core::Shape& batch_shape, + const mlx::core::Strides& a_batch_strides, + const mlx::core::Strides& b_batch_strides, + CommandEncoder& encoder); + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/jit_module.h b/dist/include/mlx/backend/cuda/jit_module.h new file mode 100644 index 0000000..ac28863 --- /dev/null +++ b/dist/include/mlx/backend/cuda/jit_module.h @@ -0,0 +1,119 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/common/utils.h" +#include "mlx/backend/cuda/device.h" +#include "mlx/backend/cuda/device/config.h" + +#include +#include +#include +#include + +#include +#include + +namespace mlx::core::cu { + +class Device; + +using KernelBuilderResult = std::tuple< + /* precompiled */ bool, + /* source code */ std::string, + /* kernel names */ std::vector>; +using KernelBuilder = std::function; + +struct KernelArgs { + void** args() { + return args_.data(); + } + + void append(const array& a) { + append(reinterpret_cast(gpu_ptr(a))); + } + + template + void append(T val) { + storage_.emplace_back(val); + append_ptr(&storage_.back()); + } + + template + void append(SmallVector vec) { + storage_.emplace_back(std::move(vec)); + append_ptr(std::get>(storage_.back()).data()); + } + + template + void append(const std::vector& vec) { + append(SmallVector(vec.begin(), vec.end())); + } + + // Make sure the arg is copied to an array with size of NDIM. 
+ template + void append_ndim(SmallVector vec) { + if (vec.size() > NDIM) { + throw std::runtime_error( + fmt::format("ndim can not be larger than {}.", NDIM)); + } + vec.resize(NDIM); + append(std::move(vec)); + } + + void append_ptr(const void* v) { + args_.push_back(const_cast(v)); + } + + private: + std::vector args_; + + // The cuGraphAddKernelNode API requires passing pointers to arguments so + // store temporary values until the node is created. + using Arg = std::variant< + std::monostate, + CUdeviceptr, + bool, + int32_t, + uint32_t, + int64_t, + float, + SmallVector, + SmallVector, + SmallVector>; + std::deque storage_; +}; + +class JitModule { + public: + JitModule( + Device& device, + const std::string& module_name, + const KernelBuilder& builder, + bool cache); + ~JitModule(); + + JitModule(const JitModule&) = delete; + JitModule& operator=(const JitModule&) = delete; + CUfunction get_kernel( + const std::string& kernel_name, + std::function configure_kernel = nullptr); + std::pair get_kernel_and_dims( + const std::string& kernel_name, + std::function configure_kernel = nullptr); + + private: + CUmodule module_{nullptr}; + std::unordered_map> kernels_; +}; + +std::unordered_map& get_jit_module_cache(); + +JitModule& get_jit_module( + const mlx::core::Device& device, + const std::string& name, + const KernelBuilder& builder, + bool use_disk_cache = true); + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/cuda/lru_cache.h b/dist/include/mlx/backend/cuda/lru_cache.h new file mode 100644 index 0000000..94a96a9 --- /dev/null +++ b/dist/include/mlx/backend/cuda/lru_cache.h @@ -0,0 +1,189 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/utils.h" + +#include +#include +#include +#include + +#include + +namespace mlx::core { + +template < + typename K, + typename V, + template typename M = std::unordered_map> +class LRUCache { + public: + using value_type = std::pair; + using list_type = std::list; + using iterator = typename list_type::iterator; + using const_iterator = typename list_type::const_iterator; + using map_type = M; + + explicit LRUCache(size_t capacity) : capacity_(capacity) { + if (capacity == 0) { + throw std::runtime_error("LRUCache requires capacity > 0."); + } + } + + // Initialize with capacity read from |env_name|. + LRUCache(const char* env_name, int default_capacity) + : LRUCache(env::get_var(env_name, default_capacity)) { + if (env::get_var("MLX_ENABLE_CACHE_THRASHING_CHECK", 1)) { + env_name_ = env_name; + } + } + + size_t size() const { + return map_.size(); + } + size_t capacity() const { + return capacity_; + } + bool empty() const { + return vlist_.empty(); + } + + void resize(size_t new_capacity) { + capacity_ = new_capacity; + trim(); + } + + iterator begin() { + return vlist_.begin(); + } + const_iterator begin() const { + return vlist_.begin(); + } + iterator end() { + return vlist_.end(); + } + const_iterator end() const { + return vlist_.end(); + } + + void clear() { + map_.clear(); + vlist_.clear(); + } + + iterator find(const K& key) { + auto it = map_.find(key); + if (it == map_.end()) + return end(); + vlist_.splice(vlist_.begin(), vlist_, it->second); + return it->second; + } + + template + std::pair emplace(const K& key, U&& value) { + auto it = map_.find(key); + if (it != map_.end()) { + vlist_.splice(vlist_.begin(), vlist_, it->second); + return {it->second, false}; + } + + if (env_name_ && ++cache_misses_ > 2 * capacity_) { + throw std::runtime_error(fmt::format( + "Cache thrashing is happening, please set the environment variable " + "{} to a larger value than {} to fix degraded performance.", + env_name_, + capacity_)); + 
} + + vlist_.emplace_front(key, std::forward(value)); + map_[key] = vlist_.begin(); + + trim(); + + return {vlist_.begin(), true}; + } + + iterator erase(iterator pos) { + map_.erase(pos->first); + return vlist_.erase(pos); + } + + V& operator[](const K& key) { + auto it = find(key); + if (it == end()) { + it = emplace(key, V{}).first; + } + return it->second; + } + + private: + void trim() { + while (map_.size() > capacity_) { + auto last = std::prev(vlist_.end()); + map_.erase(last->first); + vlist_.pop_back(); + } + } + + const char* env_name_{nullptr}; + size_t cache_misses_{0}; + + list_type vlist_; + map_type map_; + size_t capacity_; +}; + +// Turn a POD struct into a container key by doing bytes compare. +// +// Usage: +// BytesKey key; +// key.pod = { ... }; +template +struct BytesKey { + T pod; + static_assert(std::is_standard_layout_v, "T is not POD"); + + BytesKey() { + // Make sure the paddings between members are filled with 0. + memset(&pod, 0, sizeof(T)); + } + + BytesKey(const BytesKey& other) { + memcpy(&pod, &other.pod, sizeof(T)); + } + + BytesKey(BytesKey&& other) { + memcpy(&pod, &other.pod, sizeof(T)); + } + + bool operator==(const BytesKey& other) const { + auto* ptr1 = reinterpret_cast(&pod); + auto* ptr2 = reinterpret_cast(&other.pod); + return memcmp(ptr1, ptr2, sizeof(T)) == 0; + } +}; + +// Compute hash according to the bytes value of T. 
+template +struct BytesHash { + static_assert(std::is_standard_layout_v, "T is not POD"); + + size_t operator()(const T& pod) const { + auto* ptr = reinterpret_cast(&pod); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < sizeof(T); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return value; + } +}; + +template +using BytesKeyHashMap = std::unordered_map>; + +template +using LRUBytesKeyCache = LRUCache, V, BytesKeyHashMap>; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/quantized/cublas_qqmm.h b/dist/include/mlx/backend/cuda/quantized/cublas_qqmm.h new file mode 100644 index 0000000..0a710f6 --- /dev/null +++ b/dist/include/mlx/backend/cuda/quantized/cublas_qqmm.h @@ -0,0 +1,88 @@ +// Copyright © 2025 Apple Inc. +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cublas_utils.h" +#include "mlx/backend/cuda/device.h" + +#include + +namespace mlx::core { + +class CublasQQMM : public CublasMatmulBase { + public: + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + Dtype out_dtype, + std::string quantization_mode); + + CublasQQMM( + cu::Device& device, + bool a_transposed, + uint64_t a_rows, + uint64_t a_cols, + int64_t lda, + bool b_transposed, + uint64_t b_rows, + uint64_t b_cols, + int64_t ldb, + int64_t ldc, + int32_t batch_count, + int64_t a_batch_stride, + int64_t b_batch_stride, + int64_t c_batch_stride, + Dtype out_dtype, + std::string quantization_mode); + + void run( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + float alpha = 1.0f); + + private: + void run_batched( + cu::CommandEncoder& encoder, + array& out, + const array& a, + const array& b, + const array& a_scale, + const array& b_scale, + const Shape& batch_shape, + 
const Strides& a_batch_strides, + const Strides& b_batch_strides, + float alpha); + + void execute( + cu::CommandEncoder& encoder, + void* out, + const void* a, + const void* b, + const void* a_scale, + const void* b_scale, + const void* c, + float alpha = 1, + float beta = 0); + + std::string quantization_mode_; + cublasLtMatmulMatrixScale_t a_scale_mode_; + cublasLtMatmulMatrixScale_t b_scale_mode_; + cublasLtMatmulMatrixScale_t c_scale_mode_; + cublasLtMatmulMatrixScale_t out_scale_mode_; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/quantized/cuda_fp4.h b/dist/include/mlx/backend/cuda/quantized/cuda_fp4.h new file mode 100644 index 0000000..10df457 --- /dev/null +++ b/dist/include/mlx/backend/cuda/quantized/cuda_fp4.h @@ -0,0 +1,83 @@ +#pragma once + +struct __nv_fp8_e8m0 { + __device__ __nv_fp8_e8m0(float x) { + if (!std::isfinite(x)) { + __x = 0xFF; + return; + } + if (x < 0.0f) { + __x = 0x00; + return; + } + float le = std::log2f(x); + int n = static_cast(std::nearbyintf(le)); + + n = n < -127 ? -127 : n; + n = n > 127 ? 127 : n; + __x = static_cast(n + 127); + } + + __device__ operator float() { + if (__x == 0xFF) { + return std::numeric_limits::quiet_NaN(); + } + return std::ldexp(1.0f, static_cast(__x) - 127); + } + + uint8_t __x{0}; +}; + +struct __nv_fp4_e2m1 { + __device__ __nv_fp4_e2m1(float x) { + if (std::isnan(x)) { + __x = 0x7; + return; + } + + const uint8_t sign_bit = (std::signbit(x)) ? 
0x8 : 0x0; + x = std::abs(x); + + if (x > 5.0f) { + __x = 0x7; + } else if (x >= 3.5f) { + __x = 0x6; + } else if (x > 2.5f) { + __x = 0x5; + } else if (x >= 1.75f) { + __x = 0x4; + } else if (x > 1.25f) { + __x = 0x3; + } else if (x >= 0.75f) { + __x = 0x2; + } else if (x > 0.25f) { + __x = 0x1; + } else { + __x = 0x0; + } + __x |= sign_bit; + } + + __device__ operator float() { + static const float LUT[16] = { + 0.0f, + 0.5f, + 1.0f, + 1.5f, + 2.0f, + 3.0f, + 4.0f, + 6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + + return LUT[__x]; + } + uint8_t __x{0}; +}; diff --git a/dist/include/mlx/backend/cuda/quantized/qqmm_utils.h b/dist/include/mlx/backend/cuda/quantized/qqmm_utils.h new file mode 100644 index 0000000..126cc29 --- /dev/null +++ b/dist/include/mlx/backend/cuda/quantized/qqmm_utils.h @@ -0,0 +1,30 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +// Compute padded dimensions for tiled layout +// Tiles are 128 rows × 4 columns, must allocate full tiles +inline std::pair get_padded_scale_dims(int num_rows, int num_cols) { + constexpr int rows_per_tile = 128; + constexpr int cols_per_tile = 4; + + int padded_rows = + ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile; + int padded_cols = + ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile; + + return {padded_rows, padded_cols}; +} + +void repack_scales( + const array& scales, + array& scales_tiled, + cu::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/quantized/quantized.h b/dist/include/mlx/backend/cuda/quantized/quantized.h new file mode 100644 index 0000000..4f1980a --- /dev/null +++ b/dist/include/mlx/backend/cuda/quantized/quantized.h @@ -0,0 +1,45 @@ +// Copyright © 2025 Apple Inc. 
+ +#include "mlx/backend/cuda/device.h" + +namespace mlx::core { + +void affine_quantize( + const array& w, + array& wq, + array& scales, + array& biases, + int group_size_, + int bits_, + cu::CommandEncoder& enc, + const Stream& s); + +void affine_dequantize( + const array& wq, + const array& scales, + const array& biases, + array& w, + int group_size_, + int bits_, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s); + +void fp_dequantize( + const array& wq, + const array& scales, + array& w, + int group_size, + int bits, + cu::CommandEncoder& enc, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/utils.h b/dist/include/mlx/backend/cuda/utils.h new file mode 100644 index 0000000..b060880 --- /dev/null +++ b/dist/include/mlx/backend/cuda/utils.h @@ -0,0 +1,46 @@ +// Copyright © 2025 Apple Inc. + +// This file include utilities that are used by C++ code (i.e. .cpp files). + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/allocator.h" +#include "mlx/backend/cuda/cuda_utils.h" + +namespace mlx::core { + +template +inline uint max_occupancy_block_dim(T kernel) { + int _, block_dim; + if constexpr (std::is_same_v) { + CHECK_CUDA_ERROR( + cuOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel, 0, 0, 0)); + } else { + CHECK_CUDA_ERROR( + cudaOccupancyMaxPotentialBlockSize(&_, &block_dim, kernel)); + } + return block_dim; +} + +template +inline T* gpu_ptr(array& arr) { + return reinterpret_cast( + static_cast( + static_cast(arr.buffer().ptr())->data) + + arr.offset()); +} + +// For const array, keep constness in pointer unless it is untyped. +template +inline std::conditional_t, void*, const T*> gpu_ptr( + const array& arr) { + return gpu_ptr(const_cast(arr)); +} + +struct Dtype; + +// Convert Dtype to CUDA C++ types. 
+const char* dtype_to_cuda_type(const Dtype& dtype); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/cuda/worker.h b/dist/include/mlx/backend/cuda/worker.h new file mode 100644 index 0000000..8f05e7b --- /dev/null +++ b/dist/include/mlx/backend/cuda/worker.h @@ -0,0 +1,55 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/cuda/event.h" + +#include +#include +#include +#include +#include + +namespace mlx::core::cu { + +// Run tasks in worker thread, synchronized with cuda stream. +class Worker { + public: + explicit Worker(Device& d); + ~Worker(); + + Worker(const Worker&) = delete; + Worker& operator=(const Worker&) = delete; + + // Add a pending |task| that will run when consumed or commited. + void add_task(std::function task); + + // Inform worker thread to run current batches after kernels in |stream| + // finish running. + void commit(cudaStream_t stream); + + private: + static void signal(void*); + + void thread_fn(); + std::mutex mtx_; + std::condition_variable cond_; + + uint64_t committed_batch_{0}; + uint64_t signaled_batch_{0}; + + // Cuda stream and event for signaling kernel completion. + CudaStream signal_stream_; + CudaEvent signal_event_; + + bool stop_{false}; + + // Tasks are put in |pending_tasks_| first, and then moved to + // |worker_tasks_| when end_batch() is called. + using Tasks = std::vector>; + Tasks pending_tasks_; + std::map worker_tasks_; + std::thread worker_; +}; + +} // namespace mlx::core::cu diff --git a/dist/include/mlx/backend/gpu/available.h b/dist/include/mlx/backend/gpu/available.h new file mode 100644 index 0000000..476c7ac --- /dev/null +++ b/dist/include/mlx/backend/gpu/available.h @@ -0,0 +1,9 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +namespace mlx::core::gpu { + +bool is_available(); + +} // namespace mlx::core::gpu diff --git a/dist/include/mlx/backend/gpu/copy.h b/dist/include/mlx/backend/gpu/copy.h new file mode 100644 index 0000000..6e6bc79 --- /dev/null +++ b/dist/include/mlx/backend/gpu/copy.h @@ -0,0 +1,57 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/copy.h" +#include "mlx/stream.h" + +#include + +namespace mlx::core { + +// Generic copy inplace +void copy_gpu_inplace( + const array& in, + array& out, + const Shape& data_shape, + const Strides& i_strides, + const Strides& o_strides, + int64_t i_offset, + int64_t o_offset, + CopyType ctype, + const Stream& s, + std::optional dynamic_i_offset = std::nullopt, + std::optional dynamic_o_offset = std::nullopt); + +void copy_gpu(const array& src, array& out, CopyType ctype, const Stream& s); +void copy_gpu(const array& src, array& out, CopyType ctype); + +void copy_gpu_inplace( + const array& in, + array& out, + CopyType ctype, + const Stream& s); + +void copy_gpu_inplace( + const array& in, + array& out, + const Strides& i_strides, + int64_t i_offset, + CopyType ctype, + const Stream& s); + +// Fill the output with the scalar val +void fill_gpu(const array& val, array& out, const Stream& s); + +// Return a contiguous array with same shape that copies the data of |arr|. +array contiguous_copy_gpu(const array& arr, const Stream& s); + +// Copy data from |in| and transpose to |out|'s shape. +void reshape_gpu(const array& in, array& out, Stream s); + +// Like the normal ops but safe to call in eval_gpu. 
+array flatten_in_eval(const array& x, int start_axis, int end_axis, Stream s); +array reshape_in_eval(const array& x, Shape shape, Stream s); +array swapaxes_in_eval(const array& x, int axis1, int axis2); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/gpu/eval.h b/dist/include/mlx/backend/gpu/eval.h new file mode 100644 index 0000000..f646c2e --- /dev/null +++ b/dist/include/mlx/backend/gpu/eval.h @@ -0,0 +1,18 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/stream.h" + +namespace mlx::core::gpu { + +void new_stream(Stream stream); +void eval(array& arr); +void finalize(Stream s); +void synchronize(Stream s); + +} // namespace mlx::core::gpu diff --git a/dist/include/mlx/backend/gpu/slicing.h b/dist/include/mlx/backend/gpu/slicing.h new file mode 100644 index 0000000..596f7af --- /dev/null +++ b/dist/include/mlx/backend/gpu/slicing.h @@ -0,0 +1,36 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void slice_gpu( + const array& in, + array& out, + const Shape& start_indices, + const Shape& strides, + const Stream& s); + +void concatenate_gpu( + const std::vector& inputs, + array& out, + int axis, + const Stream& s); + +void pad_gpu( + const array& in, + const array& val, + array& out, + const std::vector& axes, + const Shape& low_pad_size, + const Stream& s); + +array compute_dynamic_offset( + const array& indices, + const Strides& strides, + const std::vector& axes, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/allocator.h b/dist/include/mlx/backend/metal/allocator.h new file mode 100644 index 0000000..5e177b3 --- /dev/null +++ b/dist/include/mlx/backend/metal/allocator.h @@ -0,0 +1,79 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include +#include + +#include "mlx/allocator.h" +#include "mlx/backend/common/buffer_cache.h" +#include "mlx/backend/metal/device.h" +#include "mlx/backend/metal/resident.h" + +namespace mlx::core::metal { + +using allocator::Buffer; + +class MetalAllocator : public allocator::Allocator { + /** Allocator for Metal GPUs. */ + public: + virtual Buffer malloc(size_t size) override; + virtual void free(Buffer buffer) override; + virtual size_t size(Buffer buffer) const override; + virtual Buffer make_buffer(void* ptr, size_t size) override; + virtual void release(Buffer buffer) override; + + size_t get_active_memory() { + return active_memory_; + }; + size_t get_peak_memory() { + return peak_memory_; + }; + void reset_peak_memory() { + std::unique_lock lk(mutex_); + peak_memory_ = 0; + }; + size_t get_cache_memory() { + return buffer_cache_.cache_size(); + }; + size_t set_cache_limit(size_t limit); + size_t set_memory_limit(size_t limit); + size_t get_memory_limit(); + size_t set_wired_limit(size_t limit); + void clear_cache(); + + private: + MTL::Device* device_; + + // The size of allocations which go on the heap until it is full. This size + // is chosen because it is the actual minimum size of a buffer allocated from + // the heap, a heap can have at most heap.size() / 256 buffers. 
+ static constexpr int small_size_ = 256; + static constexpr int heap_size_ = 1 << 20; + MTL::Heap* heap_; + MetalAllocator(); + ~MetalAllocator(); + friend MetalAllocator& allocator(); + + // Caching allocator + BufferCache buffer_cache_; + + ResidencySet residency_set_; + + // Allocation stats + size_t block_limit_; + size_t gc_limit_; + size_t active_memory_{0}; + size_t peak_memory_{0}; + size_t max_pool_size_; + size_t wired_limit_{0}; + size_t num_resources_{0}; + size_t resource_limit_{0}; + + std::mutex mutex_; +}; + +MetalAllocator& allocator(); + +} // namespace mlx::core::metal diff --git a/dist/include/mlx/backend/metal/binary.h b/dist/include/mlx/backend/metal/binary.h new file mode 100644 index 0000000..0341a2f --- /dev/null +++ b/dist/include/mlx/backend/metal/binary.h @@ -0,0 +1,33 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void binary_op_gpu( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s); + +void binary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void binary_op_gpu_inplace( + const std::vector& inputs, + std::vector& outputs, + const char* op, + const Stream& s); + +void binary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/device.h b/dist/include/mlx/backend/metal/device.h new file mode 100644 index 0000000..564d15a --- /dev/null +++ b/dist/include/mlx/backend/metal/device.h @@ -0,0 +1,283 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/device.h" + +namespace mlx::core::metal { + +using MTLFCList = + std::vector>; + +struct DeviceStream; + +struct CommandEncoder { + explicit CommandEncoder(DeviceStream& stream); + CommandEncoder(const CommandEncoder&) = delete; + CommandEncoder& operator=(const CommandEncoder&) = delete; + + struct ConcurrentContext { + ConcurrentContext(CommandEncoder& enc) : enc(enc) { + enc.concurrent_ = true; + } + ~ConcurrentContext() { + enc.concurrent_ = false; + enc.prev_outputs_.insert( + enc.concurrent_outputs_.begin(), enc.concurrent_outputs_.end()); + enc.concurrent_outputs_.clear(); + } + + private: + CommandEncoder& enc; + }; + + void set_input_array(const array& a, int idx, int64_t offset = 0); + void set_output_array(array& a, int idx, int64_t offset = 0); + void register_output_array(const array& a); + void dispatch_threadgroups(MTL::Size grid_dims, MTL::Size group_dims); + void dispatch_threads(MTL::Size grid_dims, MTL::Size group_dims); + void maybeInsertBarrier(); + void set_buffer(const MTL::Buffer* buf, int idx, int64_t offset = 0); + + void set_compute_pipeline_state(MTL::ComputePipelineState* kernel) { + enc_->setComputePipelineState(kernel); + } + + void wait_for_fence(MTL::Fence* fence) { + enc_->waitForFence(fence); + } + + void update_fence(MTL::Fence* fence) { + enc_->updateFence(fence); + } + + template >> + void set_vector_bytes(const Vec& vec, size_t nelems, int idx) { + enc_->setBytes(vec.data(), nelems * sizeof(typename Vec::value_type), idx); + } + template >> + void set_vector_bytes(const Vec& vec, int idx) { + return set_vector_bytes(vec, vec.size(), idx); + } + + template + void set_bytes(const T* v, int n, int idx) { + return enc_->setBytes(v, n * sizeof(T), idx); + } + + template + void set_bytes(const T& v, int idx) { + return enc_->setBytes(&v, sizeof(T), idx); + } + + void set_threadgroup_memory_length(size_t 
length, int idx) { + enc_->setThreadgroupMemoryLength(length, idx); + } + + ConcurrentContext start_concurrent() { + return ConcurrentContext(*this); + } + ~CommandEncoder(); + + // Inputs to all kernels in the encoder including temporaries + std::unordered_set& inputs() { + return all_inputs_; + }; + + // Outputs of all kernels in the encoder including temporaries + std::unordered_set& outputs() { + return all_outputs_; + }; + + void barrier(); + + private: + DeviceStream& stream_; + MTL::ComputeCommandEncoder* enc_; + bool needs_barrier_{false}; + bool concurrent_{false}; + std::unordered_set prev_outputs_; + std::unordered_set next_outputs_; + std::unordered_set concurrent_outputs_; + std::unordered_set all_inputs_; + std::unordered_set all_outputs_; +}; + +struct Fence { + Fence(MTL::Fence* fence) : fence(fence) {} + ~Fence() { + fence->release(); + } + MTL::Fence* fence; +}; + +struct DeviceStream { + DeviceStream(MTL::CommandQueue* queue) : queue(queue) {}; + ~DeviceStream() { + queue->release(); + if (buffer != nullptr) { + buffer->release(); + } + }; + MTL::CommandQueue* queue; + // A map of prior command encoder outputs to their corresponding fence + std::unordered_map> outputs; + // Used to allow thread-safe access to the outputs map + std::mutex fence_mtx; + + // Data updated between command buffers + MTL::CommandBuffer* buffer{nullptr}; + int buffer_ops{0}; + size_t buffer_sizes{0}; + + // The command encoder, fence, and temporaries are updated between command + // encoders + std::unique_ptr encoder{nullptr}; + std::shared_ptr fence; + std::vector temporaries; +}; + +class Device { + public: + Device(); + Device(const Device&) = delete; + Device& operator=(const Device&) = delete; + ~Device(); + + MTL::Device* mtl_device() { + return device_; + }; + + const std::string& get_architecture() { + return arch_; + } + + int get_architecture_gen() const { + return arch_gen_; + } + + void new_queue(int index); + + MTL::CommandQueue* get_queue(Stream stream); + 
+ MTL::CommandBuffer* get_command_buffer(int index); + bool command_buffer_needs_commit(int index); + void commit_command_buffer(int index); + CommandEncoder& get_command_encoder(int index); + void end_encoding(int index); + + MTL::Library* get_library( + const std::string& name, + const std::string& path = ""); + + MTL::Library* get_library( + const std::string& name, + const std::function& builder); + + void clear_library(const std::string& name); + + MTL::ComputePipelineState* get_kernel( + const std::string& base_name, + MTL::Library* mtl_lib, + const std::string& hash_name = "", + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::ComputePipelineState* get_kernel( + const std::string& base_name, + const std::string& hash_name = "", + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::ArgumentEncoder* argument_encoder( + const std::vector& arg_descs) const; + + // Record temporary arrays for the given stream index + void add_temporary(array arr, int index); + void add_temporaries(std::vector arrays, int index); + + void set_residency_set(const MTL::ResidencySet* residency_set); + + private: + DeviceStream& get_stream_(int index) { + return stream_map_.find(index)->second; + } + MTL::Library* get_library_cache_(const std::string& name); + + MTL::Library* get_library_(const std::string& name); + MTL::Library* build_library_(const std::string& source_string); + + MTL::Function* get_function_(const std::string& name, MTL::Library* mtl_lib); + + MTL::Function* get_function_( + const std::string& name, + const std::string& specialized_name, + const MTLFCList& func_consts, + MTL::Library* mtl_lib); + + MTL::LinkedFunctions* get_linked_functions_( + const std::vector& funcs); + + MTL::ComputePipelineState* get_kernel_( + const std::string& name, + const MTL::Function* mtl_function); + + MTL::ComputePipelineState* get_kernel_( + const std::string& name, + const MTL::Function* mtl_function, + 
const MTL::LinkedFunctions* linked_functions); + + MTL::ComputePipelineState* get_kernel_( + const std::string& base_name, + MTL::Library* mtl_lib, + const std::string& hash_name, + const MTLFCList& func_consts = {}, + const std::vector& linked_functions = {}); + + MTL::Device* device_; + std::unordered_map stream_map_; + + std::shared_mutex kernel_mtx_; + std::shared_mutex library_mtx_; + std::unordered_map library_map_; + MTL::Library* default_library_; + std::unordered_map< + MTL::Library*, + std::unordered_map> + library_kernels_; + const MTL::ResidencySet* residency_set_{nullptr}; + std::string arch_; + int arch_gen_; + int max_ops_per_buffer_; + int max_mb_per_buffer_; +}; + +Device& device(mlx::core::Device); + +std::unique_ptr> new_scoped_memory_pool(); + +inline bool is_nax_available() { + auto _check_nax = []() { + bool can_use_nax = false; + if (__builtin_available( + macOS 26.2, iOS 26.2, tvOS 26.2, visionOS 26.2, *)) { + can_use_nax = true; + } + can_use_nax &= + metal::device(mlx::core::Device::gpu).get_architecture_gen() >= 17; + return can_use_nax; + }; + static bool is_nax_available_ = _check_nax(); + return is_nax_available_; +} + +} // namespace mlx::core::metal diff --git a/dist/include/mlx/backend/metal/jit/includes.h b/dist/include/mlx/backend/metal/jit/includes.h new file mode 100644 index 0000000..f4b4db5 --- /dev/null +++ b/dist/include/mlx/backend/metal/jit/includes.h @@ -0,0 +1,57 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +namespace mlx::core::metal { + +const char* utils(); +const char* binary_ops(); +const char* unary_ops(); +const char* ternary_ops(); +const char* reduce_utils(); +const char* gather(); +const char* scatter(); +const char* masked_scatter(); + +const char* arange(); +const char* unary(); +const char* binary(); +const char* binary_two(); +const char* copy(); +const char* fft(); +const char* gather_axis(); +const char* gather_front(); +const char* hadamard(); +const char* logsumexp(); +const char* quantized_utils(); +const char* quantized(); +const char* fp_quantized(); +const char* ternary(); +const char* scan(); +const char* scatter_axis(); +const char* softmax(); +const char* sort(); +const char* reduce(); + +const char* gemm(); +const char* steel_gemm_fused(); +const char* steel_gemm_masked(); +const char* steel_gemm_splitk(); +const char* steel_gemm_gather(); +const char* steel_gemm_segmented(); +const char* conv(); +const char* steel_conv(); +const char* steel_conv_general(); +const char* gemv_masked(); +const char* steel_attention(); + +const char* gemm_nax(); +const char* steel_gemm_fused_nax(); +const char* steel_gemm_gather_nax(); + +const char* quantized_nax(); +const char* fp_quantized_nax(); + +const char* steel_attention_nax(); + +} // namespace mlx::core::metal diff --git a/dist/include/mlx/backend/metal/jit/indexing.h b/dist/include/mlx/backend/metal/jit/indexing.h new file mode 100644 index 0000000..fa141fc --- /dev/null +++ b/dist/include/mlx/backend/metal/jit/indexing.h @@ -0,0 +1,76 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +constexpr std::string_view gather_kernels = R"( +[[kernel]] void gather{0}_{3}_{6}_{7}( + const device {1}* src [[buffer(0)]], + device {1}* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant int64_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const constant int* idx_shapes [[buffer(7)]], + const constant int64_t* idx_strides [[buffer(8)]], + const constant bool* idx_contigs [[buffer(9)]], + const constant int& idx_ndim [[buffer(10)]], + {4} + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) {{ + Indices<{2}, {3}> idxs{{ + {{ {5} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; + + return gather_impl<{1}, {2}, {3}, {6}, {7}>( + src, + out, + src_shape, + src_strides, + src_ndim, + slice_sizes, + axes, + idxs, + index, + grid_dim); +}} +)"; + +constexpr std::string_view scatter_kernels = R"( +[[kernel]] void scatter{0}_{4}_updc_{7}_nwork{8}_{9}( + const device {1}* updates [[buffer(1)]], + device mlx_atomic<{1}>* out [[buffer(2)]], + const constant int* upd_shape [[buffer(3)]], + const constant int64_t* upd_strides [[buffer(4)]], + const constant size_t& upd_ndim [[buffer(5)]], + const constant size_t& upd_size [[buffer(6)]], + const constant int* out_shape [[buffer(7)]], + const constant int64_t* out_strides [[buffer(8)]], + const constant size_t& out_ndim [[buffer(9)]], + const constant int* axes [[buffer(10)]], + const constant int* idx_shapes [[buffer(11)]], + const constant int64_t* idx_strides [[buffer(12)]], + const constant bool* idx_contigs [[buffer(13)]], + const constant int& idx_ndim [[buffer(14)]], + const constant size_t& idx_size [[buffer(15)]], + {5} + uint2 gid [[thread_position_in_grid]]) {{ + Indices<{2}, {4}> idxs{{ {{ {6} }}, idx_shapes, idx_strides, idx_contigs, idx_ndim}}; + + return scatter_impl<{1}, {2}, {3}, {4}, {7}, {8}, {9}>( + updates, + out, 
+ upd_shape, + upd_strides, + upd_ndim, + upd_size, + out_shape, + out_strides, + out_ndim, + axes, + idx_size, + idxs, + gid); +}} +)"; + +constexpr std::string_view masked_assign_kernel = R"( +template [[host_name("{0}")]] [[kernel]] decltype(masked_assign_impl<{1}, {2}>) masked_assign_impl<{1}, {2}>; +)"; diff --git a/dist/include/mlx/backend/metal/kernels/arange.h b/dist/include/mlx/backend/metal/kernels/arange.h new file mode 100644 index 0000000..5448fe9 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/arange.h @@ -0,0 +1,9 @@ +// Copyright © 2023-2024 Apple Inc. +template +[[kernel]] void arange( + constant const T& start, + constant const T& step, + device T* out, + uint index [[thread_position_in_grid]]) { + out[index] = start + index * step; +} diff --git a/dist/include/mlx/backend/metal/kernels/atomic.h b/dist/include/mlx/backend/metal/kernels/atomic.h new file mode 100644 index 0000000..93952c2 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/atomic.h @@ -0,0 +1,345 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// Atomic utils +/////////////////////////////////////////////////////////////////////////////// + +#pragma METAL internals : enable +template +constexpr constant bool is_metal_atomic = _disjunction< + is_same, + is_same, + is_same, + is_same>::value; + +#pragma METAL internals : disable + +template +struct mlx_atomic { + atomic val; +}; + +template +struct mlx_atomic>> { + atomic val; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Native metal atomics +/////////////////////////////////////////////////////////////////////////////// + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { + atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> 
= true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + T expected = mlx_atomic_load_explicit(object, offset); + while (!mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val * expected, offset)) { + } +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread T* expected, + T val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + memory_order_relaxed); +} + +// Specialization for float since it does not atomic_fetch_min_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val < expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +// Specialization for float since it does not atomic_fetch_max_explicit +template <> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + float val, + size_t offset) { + float expected = mlx_atomic_load_explicit(object, offset); + while (val > expected) { + if (mlx_atomic_compare_exchange_weak_explicit( + object, &expected, val, offset)) { + return; + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Custom atomics +/////////////////////////////////////////////////////////////////////////////// + +namespace { + +template +constexpr constant uint packing_size = sizeof(uint) / sizeof(T); + +template +union uint_or_packed { + T val[packing_size]; + uint bits; +}; + +template +struct mlx_atomic_update_helper { + uint 
operator()(uint_or_packed init, T update, size_t elem_offset) { + Op op; + init.val[elem_offset] = op(update, init.val[elem_offset]); + return init.bits; + } +}; + +template +METAL_FUNC void mlx_atomic_update_and_store( + device mlx_atomic* object, + T update, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + + mlx_atomic_update_helper helper; + uint_or_packed expected; + expected.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + + while (Op::condition(update, expected.val[elem_offset]) && + !mlx_atomic_compare_exchange_weak_explicit( + object, + &(expected.bits), + helper(expected, update, elem_offset), + pack_offset)) { + } +} + +template +struct __None { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { +#pragma unused(b) + return a; + } +}; + +template +struct __Add { + static bool condition(T a, T b) { +#pragma unused(a) +#pragma unused(b) + return true; + } + + T operator()(T a, T b) { + return a + b; + } +}; + +template +struct __Mul { + static bool condition(T a, T b) { +#pragma unused(a) + return b != 0; + } + + T operator()(T a, T b) { + return a * b; + } +}; + +template +struct __Max { + static bool condition(T a, T b) { + return a > b; + } + + T operator()(T a, T b) { + return max(a, b); + } +}; + +template +struct __Min { + static bool condition(T a, T b) { + return a < b; + } + + T operator()(T a, T b) { + return min(a, b); + } +}; + +} // namespace + +template , bool> = true> +METAL_FUNC T +mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { + size_t pack_offset = offset / sizeof(T); + size_t elem_offset = offset % sizeof(T); + uint_or_packed packed_val; + packed_val.bits = + atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); + return packed_val.val[elem_offset]; +} + +template , bool> = true> +METAL_FUNC void +mlx_atomic_store_explicit(device 
mlx_atomic* object, T val, size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_and_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = __UINT32_MAX__; + identity.val[elem_offset] = val; + + atomic_fetch_and_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_or_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + size_t pack_offset = offset / packing_size; + size_t elem_offset = offset % packing_size; + uint_or_packed identity; + identity.bits = 0; + identity.val[elem_offset] = val; + + atomic_fetch_or_explicit( + &(object[pack_offset].val), identity.bits, memory_order_relaxed); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_min_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_max_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_add_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC void mlx_atomic_fetch_mul_explicit( + device mlx_atomic* object, + T val, + size_t offset) { + mlx_atomic_update_and_store>(object, val, offset); +} + +template , bool> = true> +METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( + device mlx_atomic* object, + thread uint* expected, + uint val, + size_t offset) { + return atomic_compare_exchange_weak_explicit( + &(object[offset].val), + expected, + val, + memory_order_relaxed, + 
memory_order_relaxed); +} diff --git a/dist/include/mlx/backend/metal/kernels/bf16.h b/dist/include/mlx/backend/metal/kernels/bf16.h new file mode 100644 index 0000000..aa3c3c7 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/bf16.h @@ -0,0 +1,16 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +typedef bfloat bfloat16_t; +inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { + return as_type(x); +} + +inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { + return as_type(x); +} diff --git a/dist/include/mlx/backend/metal/kernels/bf16_math.h b/dist/include/mlx/backend/metal/kernels/bf16_math.h new file mode 100644 index 0000000..0643fb3 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/bf16_math.h @@ -0,0 +1,380 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Metal math for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +/* + +Following the Metal Shading Language Specification (Metal 3.1) + +"bfloat is an extended itypeing point type that only allows implicit conversion + to a type of greater itypeing point rank. While bfloat can be implicitly + converted to itype, it cannot be implicitly converted to half, and neither + itype nor half can be implicitly converted to bfloat." 
+ +Further, as far as I can tell, the stdlib math/simd functions are not defined +for bfloat and calling with an argument of type bfloat will result in that +argument getting implicitly converted to itype which then returns an output +that is (likely) a itype which cannot be implicitly converted into a bfloat + +This leads to situations where +bfloat a = 5.0bf; +bfloat b = metal::abs(a); // this will throw an error since abs return itype +bfloat c = static_cast(metal::abs(a)); // this is fine + +For the moment, I will be adding overloaded instantiations of the math +functions to accordingly automatically handle the casting + +*/ + +#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ + \ + METAL_FUNC otype abs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acos(itype x) { \ + return static_cast(__metal_acos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype acosh(itype x) { \ + return static_cast(__metal_acosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asin(itype x) { \ + return static_cast(__metal_asin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype asinh(itype x) { \ + return static_cast(__metal_asinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atan(itype y_over_x) { \ + return static_cast( \ + __metal_atan(static_cast(y_over_x), mfast)); \ + } \ + METAL_FUNC otype atan2(itype y, itype x) { \ + return static_cast( \ + __metal_atan2(static_cast(y), static_cast(x), mfast)); \ + } \ + METAL_FUNC otype atanh(itype x) { \ + return static_cast(__metal_atanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype ceil(itype x) { \ + return static_cast(__metal_ceil(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cos(itype x) { \ + return static_cast(__metal_cos(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cosh(itype x) { \ + return static_cast(__metal_cosh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype cospi(itype x) { \ + return static_cast(__metal_cospi(static_cast(x), 
mfast)); \ + } \ + METAL_FUNC otype divide(itype x, itype y) { \ + return static_cast( \ + __metal_divide(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype exp(itype x) { \ + return static_cast(__metal_exp(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp10(itype x) { \ + return static_cast(__metal_exp10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype exp2(itype x) { \ + return static_cast(__metal_exp2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fabs(itype x) { \ + return static_cast(__metal_fabs(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fdim(itype x, itype y) { \ + ctype t = static_cast(x - y); \ + return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ + } \ + METAL_FUNC otype floor(itype x) { \ + return static_cast(__metal_floor(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype fma(itype x, itype y, itype z) { \ + return static_cast(__metal_fma( \ + static_cast(x), static_cast(y), static_cast(z))); \ + } \ + METAL_FUNC otype fmax(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmin(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype fmod(itype x, itype y) { \ + return static_cast( \ + __metal_fmod(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype fract(itype x) { \ + return static_cast(__metal_fract(static_cast(x), 
mfast)); \ + } \ + METAL_FUNC otype frexp(itype x, thread int& exp) { \ + return static_cast(__metal_frexp(static_cast(x), &exp)); \ + } \ + METAL_FUNC otype ldexp(itype x, int k) { \ + return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ + } \ + METAL_FUNC otype log(itype x) { \ + return static_cast(__metal_log(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log10(itype x) { \ + return static_cast(__metal_log10(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype log2(itype x) { \ + return static_cast(__metal_log2(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype max(itype x, itype y) { \ + return static_cast( \ + __metal_fmax(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype max3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmax3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype median3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmedian3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype min(itype x, itype y) { \ + return static_cast( \ + __metal_fmin(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype min3(itype x, itype y, itype z) { \ + return static_cast(__metal_fmin3( \ + static_cast(x), \ + static_cast(y), \ + static_cast(z), \ + mfast)); \ + } \ + METAL_FUNC otype nextafter(itype x, itype y) { \ + return static_cast( \ + __metal_nextafter(static_cast(x), static_cast(y))); \ + } \ + METAL_FUNC otype pow(itype x, itype y) { \ + return static_cast( \ + __metal_pow(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype powr(itype x, itype y) { \ + return static_cast( \ + __metal_powr(static_cast(x), static_cast(y), mfast)); \ + } \ + METAL_FUNC otype rint(itype x) { \ + return static_cast(__metal_rint(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype round(itype x) { \ + return static_cast(__metal_round(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype rsqrt(itype 
x) { \ + return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sin(itype x) { \ + return static_cast(__metal_sin(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinh(itype x) { \ + return static_cast(__metal_sinh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sinpi(itype x) { \ + return static_cast(__metal_sinpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype sqrt(itype x) { \ + return static_cast(__metal_sqrt(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tan(itype x) { \ + return static_cast(__metal_tan(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanh(itype x) { \ + return static_cast(__metal_tanh(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype tanpi(itype x) { \ + return static_cast(__metal_tanpi(static_cast(x), mfast)); \ + } \ + METAL_FUNC otype trunc(itype x) { \ + return static_cast(__metal_trunc(static_cast(x), mfast)); \ + } + +namespace metal { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_MAYBE_FAST_MATH__); + +namespace fast { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_FAST_MATH__); + +} // namespace fast + +namespace precise { + +instantiate_metal_math_funcs( + bfloat16_t, + bfloat16_t, + float, + __METAL_PRECISE_MATH__); + +} // namespace precise + +} // namespace metal + +/////////////////////////////////////////////////////////////////////////////// +// Metal simd for bfloat16 +/////////////////////////////////////////////////////////////////////////////// + +#define instantiate_metal_simd_comm_funcs( \ + itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ + \ + METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ + } \ + \ + METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ + return ctype_to_otype( \ + __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ + } \ + \ + 
METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_down( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta, ushort modulo) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_and_fill_up( \ + itype data, itype filling_data, ushort delta) { \ + return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ + itype_to_ctype(data), \ + itype_to_ctype(filling_data), \ + delta, \ + __metal_get_simdgroup_size(ushort()))); \ + } \ + \ + METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ + } \ + \ + METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ + return ctype_to_otype( \ + __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ + } + +#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ + \ 
+ METAL_FUNC otype simd_max(itype data) { \ + return static_cast(__metal_simd_max(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_min(itype data) { \ + return static_cast(__metal_simd_min(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_exclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ + return static_cast( \ + __metal_simd_prefix_inclusive_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_product(itype data) { \ + return static_cast(__metal_simd_product(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_sum(itype data) { \ + return static_cast(__metal_simd_sum(static_cast(data))); \ + } \ + \ + METAL_FUNC otype simd_xor(itype data) { \ + return static_cast(__metal_simd_xor(static_cast(data))); \ + } + +namespace metal { + +instantiate_metal_simd_comm_funcs( + bfloat16_t, + bfloat16_t, + uint16_t, + bfloat16_to_uint16, + uint16_to_bfloat16); +instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); + +} // namespace metal diff --git a/dist/include/mlx/backend/metal/kernels/binary.h b/dist/include/mlx/backend/metal/kernels/binary.h new file mode 100644 index 0000000..f1df885 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/binary.h @@ -0,0 +1,199 @@ +// Copyright © 2024 Apple Inc. 
+ +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + uint index [[thread_position_in_grid]]) { + c[index] = Op()(a[0], b[0]); +} + +template ::n> +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } +} + +template ::n> +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } +} + +template ::n> +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } +} + +template ::n> +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } +} + +template ::n> +[[kernel]] void binary_vs2( + device const T* a, + device const T* 
b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } +} + +template ::n> +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + constant const int64_t& a_stride, + constant const int64_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + c[index] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim 
[[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + c[out_idx] = Op()(a[a_idx], b[b_idx]); +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + c[out_idx++] = Op()(a[idx.x], b[idx.y]); + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/binary_ops.h b/dist/include/mlx/backend/metal/kernels/binary_ops.h new file mode 100644 index 0000000..cb3e8a3 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/binary_ops.h @@ -0,0 +1,326 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include + +struct Add { + template + T operator()(T x, T y) { + return x + y; + } +}; + +struct FloorDivide { + template + T operator()(T x, T y) { + return x / y; + } + template <> + float operator()(float x, float y) { + return trunc(x / y); + } + template <> + half operator()(half x, half y) { + return trunc(x / y); + } + template <> + bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { + return trunc(x / y); + } +}; + +struct Divide { + template + T operator()(T x, T y) { + return x / y; + } +}; + +struct Remainder { + template + metal::enable_if_t & !metal::is_signed_v, T> + operator()(T x, T y) { + return x % y; + } + template + metal::enable_if_t & metal::is_signed_v, T> + operator()(T x, T y) { + auto r = x % y; + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template + metal::enable_if_t, T> operator()(T x, T y) { + T r = fmod(x, y); + if (r != 0 && (r < 0 != y < 0)) { + r += y; + } + return r; + } + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + return x % y; + } +}; + +struct Equal { + template + bool operator()(T x, T y) { + return x == y; + } +}; + +struct NaNEqual { + template + bool operator()(T x, T y) { + return x == y || (metal::isnan(x) && metal::isnan(y)); + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x == y || + (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && + metal::isnan(y.imag)) || + (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || + (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); + } +}; + +struct Greater { + template + bool operator()(T x, T y) { + return x > y; + } +}; + +struct GreaterEqual { + template + bool operator()(T x, T y) { + return x >= y; + } +}; + +struct Less { + template + bool operator()(T x, T y) { + return x < y; + } +}; + +struct LessEqual { + template + bool operator()(T x, T y) { + return x <= y; + } +}; + +struct LogAddExp { + template + T 
operator()(T x, T y) { + if (metal::isnan(x) || metal::isnan(y)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr T inf = metal::numeric_limits::infinity(); + T maxval = metal::max(x, y); + T minval = metal::min(x, y); + return (minval == -inf || maxval == inf) + ? maxval + : (maxval + log1p(metal::exp(minval - maxval))); + }; + + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || + metal::isnan(y.imag)) { + return metal::numeric_limits::quiet_NaN(); + } + constexpr float inf = metal::numeric_limits::infinity(); + complex64_t maxval = x > y ? x : y; + complex64_t minval = x < y ? x : y; + if (minval.real == -inf || maxval.real == inf) + return maxval; + float m = metal::exp(minval.real - maxval.real); + complex64_t dexp{ + m * metal::cos(minval.imag - maxval.imag), + m * metal::sin(minval.imag - maxval.imag), + }; + return maxval + log1p(dexp); + } +}; + +struct Maximum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::max(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x > y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x > y ? x : y; + } +}; + +struct Minimum { + template + metal::enable_if_t, T> operator()(T x, T y) { + return metal::min(x, y); + } + + template + metal::enable_if_t, T> operator()(T x, T y) { + if (metal::isnan(x)) { + return x; + } + return x < y ? x : y; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (metal::isnan(x.real) || metal::isnan(x.imag)) { + return x; + } + return x < y ? 
x : y; + } +}; + +struct Multiply { + template + T operator()(T x, T y) { + return x * y; + } +}; + +struct NotEqual { + template + bool operator()(T x, T y) { + return x != y; + } + template <> + bool operator()(complex64_t x, complex64_t y) { + return x.real != y.real || x.imag != y.imag; + } +}; + +struct Power { + template + metal::enable_if_t, T> operator()(T base, T exp) { + return metal::pow(base, exp); + } + + template + metal::enable_if_t, T> operator()(T base, T exp) { + T res = 1; + // Undefined to raise integer to negative power + if (exp < 0) { + return 0; + } + + while (exp) { + if (exp & 1) { + res *= base; + } + exp >>= 1; + base *= base; + } + return res; + } + + template <> + complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } + auto x_theta = metal::atan2(x.imag, x.real); + auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); + auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); + auto phase = y.imag * x_ln_r + y.real * x_theta; + return {mag * metal::cos(phase), mag * metal::sin(phase)}; + } +}; + +struct Subtract { + template + T operator()(T x, T y) { + return x - y; + } +}; + +struct LogicalAnd { + template + T operator()(T x, T y) { + return x && y; + }; +}; + +struct LogicalOr { + template + T operator()(T x, T y) { + return x || y; + }; +}; + +struct BitwiseAnd { + template + T operator()(T x, T y) { + return x & y; + }; +}; + +struct BitwiseOr { + template + T operator()(T x, T y) { + return x | y; + }; +}; + +struct BitwiseXor { + template + T operator()(T x, T y) { + return x ^ y; + }; +}; + +struct LeftShift { + template + T operator()(T x, T y) { + return x << y; + }; +}; + +struct RightShift { + template + T operator()(T x, T y) { + return x >> y; + }; +}; + +struct ArcTan2 { + template + T operator()(T y, T x) { + return 
metal::precise::atan2(y, x); + } +}; + +struct DivMod { + template + metal::array operator()(T x, T y) { + return {FloorDivide{}(x, y), Remainder{}(x, y)}; + }; +}; diff --git a/dist/include/mlx/backend/metal/kernels/binary_two.h b/dist/include/mlx/backend/metal/kernels/binary_two.h new file mode 100644 index 0000000..4455e4c --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/binary_two.h @@ -0,0 +1,244 @@ +// Copyright © 2024 Apple Inc. + +template +[[kernel]] void binary_ss( + device const T* a, + device const T* b, + device U* c, + device U* d, + uint index [[thread_position_in_grid]]) { + auto out = Op()(a[0], b[0]); + c[index] = out[0]; + d[index] = out[1]; +} + +template ::n> +[[kernel]] void binary_sv( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vs( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vv( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = 
Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_sv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vs2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template ::n> +[[kernel]] void binary_vv2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto 
out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } +} + +template +[[kernel]] void binary_g_nd1( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t& a_stride, + constant const int64_t& b_stride, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_stride); + auto b_idx = elem_to_loc_1(index, b_stride); + auto out = Op()(a[a_idx], b[b_idx]); + c[index] = out[0]; + d[index] = out[1]; +} + +template +[[kernel]] void binary_g_nd2( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template +[[kernel]] void binary_g_nd3( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + auto out = Op()(a[a_idx], b[b_idx]); + c[out_idx] = out[0]; + d[out_idx] = out[1]; +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void binary_g( + device const T* a, + device const T* b, + device U* c, + device U* d, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = 
elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + auto out = Op()(a[idx.x], b[idx.y]); + c[out_idx] = out[0]; + d[out_idx++] = out[1]; + idx.x += a_xstride; + idx.y += b_xstride; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/cexpf.h b/dist/include/mlx/backend/metal/kernels/cexpf.h new file mode 100644 index 0000000..b45fe6a --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/cexpf.h @@ -0,0 +1,134 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2008-2013 NVIDIA Corporation +// Copyright © 2013 Filipe RNC Maia +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Forked from +// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h + +// TODO: We should use thrust::exp but the thrust header in old CUDA versions +// can not be used in JIT. 
+ +#pragma once + +#include + +using ieee_float_shape_type = union { + float value; + uint32_t word; +}; + +inline void get_float_word(thread uint32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void get_float_word(thread int32_t& i, float d) { + ieee_float_shape_type gf_u; + gf_u.value = (d); + (i) = gf_u.word; +} + +inline void set_float_word(thread float& d, uint32_t i) { + ieee_float_shape_type sf_u; + sf_u.word = (i); + (d) = sf_u.value; +} + +inline float frexp_expf(float x, thread int* expt) { + const uint32_t k = 235; + const float kln2 = 162.88958740F; + + float exp_x; + uint32_t hx; + + exp_x = metal::exp(x - kln2); + get_float_word(hx, exp_x); + *expt = (hx >> 23) - (0x7f + 127) + k; + set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); + return exp_x; +} + +inline complex64_t ldexp_cexpf(complex64_t z, int expt) { + float x, y, exp_x, scale1, scale2; + int ex_expt, half_expt; + + x = z.real; + y = z.imag; + exp_x = frexp_expf(x, &ex_expt); + expt += ex_expt; + + half_expt = expt / 2; + set_float_word(scale1, (0x7f + half_expt) << 23); + half_expt = expt - half_expt; + set_float_word(scale2, (0x7f + half_expt) << 23); + + return complex64_t{ + metal::cos(y) * exp_x * scale1 * scale2, + metal::sin(y) * exp_x * scale1 * scale2}; +} + +inline complex64_t cexpf(const thread complex64_t& z) { + float x, y, exp_x; + uint32_t hx, hy; + + const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; + + x = z.real; + y = z.imag; + + get_float_word(hy, y); + hy &= 0x7fffffff; + + /* cexp(x + I 0) = exp(x) + I 0 */ + if (hy == 0) { + return complex64_t{metal::exp(x), y}; + } + get_float_word(hx, x); + /* cexp(0 + I y) = cos(y) + I sin(y) */ + if ((hx & 0x7fffffff) == 0) { + return complex64_t{metal::cos(y), metal::sin(y)}; + } + if (hy >= 0x7f800000) { + if ((hx & 0x7fffffff) != 0x7f800000) { + /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ + return complex64_t{y - y, y - y}; + } else if (hx & 
0x80000000) { + /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ + return complex64_t{0.0, 0.0}; + } else { + /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ + return complex64_t{x, y - y}; + } + } + + if (hx >= exp_ovfl && hx <= cexp_ovfl) { + /* + * x is between 88.7 and 192, so we must scale to avoid + * overflow in expf(x). + */ + return ldexp_cexpf(z, 0); + } else { + /* + * Cases covered here: + * - x < exp_ovfl and exp(x) won't overflow (common case) + * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 + * - x = +-Inf (generated by exp()) + * - x = NaN (spurious inexact exception from y) + */ + exp_x = metal::exp(x); + return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/complex.h b/dist/include/mlx/backend/metal/kernels/complex.h new file mode 100644 index 0000000..6e39148 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/complex.h @@ -0,0 +1,173 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +using namespace metal; + +struct complex64_t; + +template +static constexpr constant bool can_convert_to_complex64 = + !is_same_v && is_convertible_v; + +template +static constexpr constant bool can_convert_from_complex64 = + !is_same_v && + (is_convertible_v || is_convertible_v); + +struct complex64_t { + float real; + float imag; + + // Constructors + constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; + constexpr complex64_t() : real(0), imag(0) {}; + constexpr complex64_t() threadgroup : real(0), imag(0) {}; + + // Conversions to complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) thread : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} + + template < + typename T, + typename = typename enable_if>::type> + constexpr complex64_t(T x) device : real(x), imag(0) {} + + template < + typename T, + 
typename = typename enable_if>::type> + constexpr complex64_t(T x) constant : real(x), imag(0) {} + + // Conversions from complex64_t + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const thread { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const threadgroup { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const device { + return static_cast(real); + } + + template < + typename T, + typename = typename enable_if>::type> + constexpr operator T() const constant { + return static_cast(real); + } +}; + +constexpr complex64_t operator-(complex64_t x) { + return {-x.real, -x.imag}; +} + +constexpr bool operator>=(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); +} + +constexpr bool operator>(complex64_t a, complex64_t b) { + return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); +} + +constexpr bool operator<=(complex64_t a, complex64_t b) { + return operator>=(b, a); +} + +constexpr bool operator<(complex64_t a, complex64_t b) { + return operator>(b, a); +} + +constexpr bool operator==(complex64_t a, complex64_t b) { + return a.real == b.real && a.imag == b.imag; +} + +constexpr complex64_t operator+(complex64_t a, complex64_t b) { + return {a.real + b.real, a.imag + b.imag}; +} + +constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr threadgroup complex64_t& operator+=( + threadgroup complex64_t& a, + complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) { + a.real += b.real; + a.imag += b.imag; + return a; +} + +constexpr complex64_t operator+(float a, complex64_t b) { + return {a + b.real, b.imag}; +} 
+constexpr complex64_t operator+(complex64_t a, float b) { + return {a.real + b, a.imag}; +} + +constexpr complex64_t operator-(complex64_t a, complex64_t b) { + return {a.real - b.real, a.imag - b.imag}; +} +constexpr complex64_t operator-(float a, complex64_t b) { + return {a - b.real, -b.imag}; +} +constexpr complex64_t operator-(complex64_t a, float b) { + return {a.real - b, a.imag}; +} + +constexpr complex64_t operator*(complex64_t a, complex64_t b) { + return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; +} + +constexpr complex64_t operator/(complex64_t a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a.real * b.real + a.imag * b.imag; + auto y = a.imag * b.real - a.real * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator/(float a, complex64_t b) { + auto denom = b.real * b.real + b.imag * b.imag; + auto x = a * b.real; + auto y = -a * b.imag; + return {x / denom, y / denom}; +} + +constexpr complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real - (b.real * static_cast(a.real / b.real)); + auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); + if (real != 0 && (real < 0 != b.real < 0)) { + real += b.real; + } + if (imag != 0 && (imag < 0 != b.imag < 0)) { + imag += b.imag; + } + return {real, imag}; +} diff --git a/dist/include/mlx/backend/metal/kernels/copy.h b/dist/include/mlx/backend/metal/kernels/copy.h new file mode 100644 index 0000000..cf22347 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/copy.h @@ -0,0 +1,276 @@ +// Copyright © 2024 Apple Inc. 
+ +template ::n> +[[kernel]] void copy_s( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } +} + +template ::n> +[[kernel]] void copy_v( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } +} + +template ::n> +[[kernel]] void copy_s2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } +} + +template ::n> +[[kernel]] void copy_v2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } +} + +template +[[kernel]] void copy_g_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& 
src_stride [[buffer(3)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + dst[index] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + IdxT dst_idx = + index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_g( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto src_idx = elem_to_loc( + {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); + if (N == 1) { + IdxT dst_idx = + index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + dst[dst_idx] = static_cast(src[src_idx]); + return; + } + auto xshape = src_shape[ndim - 1]; + IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + auto src_xstride = src_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[dst_idx + i] = static_cast(src[src_idx]); + src_idx += src_xstride; + } +} + +template +[[kernel]] void copy_gg_nd1( + device const T* src [[buffer(0)]], + device U* dst 
[[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx] = static_cast(src[src_idx]); +} + +template +[[kernel]] void copy_gg( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + uint3 index [[thread_position_in_grid]]) { + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = static_cast(src[idx.x]); + return; + } + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = static_cast(src[idx.x]); + idx.x += src_xstride; + idx.y += dst_xstride; + } +} + +template +[[kernel]] void 
copy_gg_dynamic_nd1( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t& src_stride [[buffer(3)]], + constant const int64_t& dst_stride [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_1(index, src_stride); + auto dst_idx = elem_to_loc_1(index, dst_stride); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd2( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint2 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_2(index, src_strides); + auto dst_idx = elem_to_loc_2(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic_nd3( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 index [[thread_position_in_grid]]) { + auto src_idx = elem_to_loc_3(index, src_strides); + auto dst_idx = elem_to_loc_3(index, dst_strides); + dst[dst_idx + dst_offset] = src[src_idx + src_offset]; +} + +template +[[kernel]] void copy_gg_dynamic( + device const T* src [[buffer(0)]], + device U* dst [[buffer(1)]], + constant const int* src_shape [[buffer(2)]], + constant const int64_t* src_strides [[buffer(3)]], + constant const int64_t* dst_strides [[buffer(4)]], + constant const int& ndim [[buffer(5)]], + constant const int64_t& src_offset [[buffer(6)]], + constant const int64_t& dst_offset [[buffer(7)]], + uint3 
index [[thread_position_in_grid]]) { + src += src_offset; + dst += dst_offset; + auto idx = elem_to_loc_2_nd( + {N * index.x, index.y, index.z}, + src_shape, + src_strides, + dst_strides, + ndim); + if (N == 1) { + dst[idx.y] = src[idx.x]; + return; + } + IdxT src_xstride = src_strides[ndim - 1]; + IdxT dst_xstride = dst_strides[ndim - 1]; + auto xshape = src_shape[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + dst[idx.y] = src[idx.x]; + idx.x += src_xstride; + idx.y += dst_xstride; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/defines.h b/dist/include/mlx/backend/metal/kernels/defines.h new file mode 100644 index 0000000..c369adb --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/defines.h @@ -0,0 +1,24 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#if defined __METAL__ || defined MLX_METAL_JIT +#define MTL_CONST constant +#else +#define MTL_CONST +#endif + +static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; +static MTL_CONST constexpr int REDUCE_N_READS = 4; +static MTL_CONST constexpr int REDUCE_N_WRITES = 4; +static MTL_CONST constexpr int SOFTMAX_N_READS = 4; +static MTL_CONST constexpr int RMS_N_READS = 4; +static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; + +// Instantiate a templated kernel. +// Extra args are used as template parameters: +// e.g. instantiate_kernel(binary_int, binary, a, b) -> +// [[host_name(binary_int)]] [kernel] binary +#define instantiate_kernel(name, func, ...) \ + template [[host_name( \ + name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; diff --git a/dist/include/mlx/backend/metal/kernels/erf.h b/dist/include/mlx/backend/metal/kernels/erf.h new file mode 100644 index 0000000..da6c2ea --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/erf.h @@ -0,0 +1,69 @@ +// Copyright © 2023 Apple Inc. + +#pragma once +#include + +/* + * Approximation to the error function. 
+ * Based on code from: + * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 + */ +float erf(float a) { + float r, s, t, u; + t = metal::abs(a); + s = a * a; + if (t > 0.927734375f) { + // maximum error 0.99527 ulp + r = metal::fma( + -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 + u = metal::fma( + -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 + r = metal::fma(r, s, u); + r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 + r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 + r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 + r = metal::fma(r, t, -t); + // TODO, replace with expm1 when implemented + r = 1.0f - metal::exp(r); + r = metal::copysign(r, a); + } else { + // maximum error 0.98929 ulp + r = -5.96761703e-4f; // -0x1.38e000p-11 + r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 + r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 + r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 + r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 + r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 + r = metal::fma(r, a, a); + } + return r; +} + +float erfinv(float a) { + auto t = metal::fma(a, 0.0f - a, 1.0f); + t = metal::log(t); + float p; + if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 + p = 3.03697567e-10f; // 0x1.4deb44p-32 + p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 + p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 + p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 + p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 + p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 + p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 + p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 + p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 + } else { // maximum ulp error = 2.35002 + p = 5.43877832e-9f; // 0x1.75c000p-28 + p = metal::fma(p, 
t, 1.43285448e-7f); // 0x1.33b402p-23 + p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 + p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 + p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 + p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 + p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 + p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 + p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 + p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 + } + return a * p; +} diff --git a/dist/include/mlx/backend/metal/kernels/expm1f.h b/dist/include/mlx/backend/metal/kernels/expm1f.h new file mode 100644 index 0000000..68224e1 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/expm1f.h @@ -0,0 +1,90 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +// Original license copied below: +// Copyright (c) 2015-2023 Norbert Juffa +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 + + i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. + Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). + With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, + when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. + + NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) +*/ +float expm1f_scaled_unchecked(float a, float b) { + float f, j, r, s, t, u, v, x, y; + int i; + + // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) + j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 + j = j - 12582912.0f; // 0x1.8p23 + i = (int)j; + f = fma(j, -6.93145752e-1f, a); + + // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] + s = f * f; + if (a == 0.0f) + s = a; // ensure -0 is passed through + // err = 0.997458 ulp1 = 11081805 + r = 1.97350979e-4f; // 0x1.9de000p-13 + r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 + r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 + r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 + r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 + r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 + u = (j == 1) ? 
(f + 0.5f) : f; + v = fma(r, s, u); + s = 0.5f * b; + t = ldexp(s, i); + y = t - s; + x = (t - y) - s; // double-float canonicalization of difference + r = fma(v, t, x) + y; + r = r + r; + if (j == 0) + r = v; + if (j == 1) + r = v + v; + return r; +} + +/* Compute exponential base e minus 1. max ulp err = 0.99746 */ +float expm1f(float a) { + float r; + + r = expm1f_scaled_unchecked(a, 1.0f); + /* handle severe overflow and underflow */ + if (abs(a - 1.0f) > 88.0f) { + r = pow(2, a); + r = fma(r, r, -1.0f); + } + return r; +} diff --git a/dist/include/mlx/backend/metal/kernels/fft.h b/dist/include/mlx/backend/metal/kernels/fft.h new file mode 100644 index 0000000..e478a85 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fft.h @@ -0,0 +1,486 @@ +// Copyright © 2024 Apple Inc. + +// Metal FFT using Stockham's algorithm +// +// References: +// - VkFFT (https://github.com/DTolm/VkFFT) +// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" +#include "mlx/backend/metal/kernels/fft/readwrite.h" +#include "mlx/backend/metal/kernels/steel/defines.h" + +using namespace metal; + +#define MAX_RADIX 13 +// Reached when elems_per_thread_ = 6, max_radix = 13 +// and some threads have to do 3 radix 6s requiring 18 float2s. 
+#define MAX_OUTPUT_SIZE 18 + +// Specialize for a particular value of N at runtime +STEEL_CONST bool inv_ [[function_constant(0)]]; +STEEL_CONST bool is_power_of_2_ [[function_constant(1)]]; +STEEL_CONST int elems_per_thread_ [[function_constant(2)]]; +// rader_m = n / rader_n +STEEL_CONST int rader_m_ [[function_constant(3)]]; +// Stockham steps +STEEL_CONST int radix_13_steps_ [[function_constant(4)]]; +STEEL_CONST int radix_11_steps_ [[function_constant(5)]]; +STEEL_CONST int radix_8_steps_ [[function_constant(6)]]; +STEEL_CONST int radix_7_steps_ [[function_constant(7)]]; +STEEL_CONST int radix_6_steps_ [[function_constant(8)]]; +STEEL_CONST int radix_5_steps_ [[function_constant(9)]]; +STEEL_CONST int radix_4_steps_ [[function_constant(10)]]; +STEEL_CONST int radix_3_steps_ [[function_constant(11)]]; +STEEL_CONST int radix_2_steps_ [[function_constant(12)]]; +// Rader steps +STEEL_CONST int rader_13_steps_ [[function_constant(13)]]; +STEEL_CONST int rader_11_steps_ [[function_constant(14)]]; +STEEL_CONST int rader_8_steps_ [[function_constant(15)]]; +STEEL_CONST int rader_7_steps_ [[function_constant(16)]]; +STEEL_CONST int rader_6_steps_ [[function_constant(17)]]; +STEEL_CONST int rader_5_steps_ [[function_constant(18)]]; +STEEL_CONST int rader_4_steps_ [[function_constant(19)]]; +STEEL_CONST int rader_3_steps_ [[function_constant(20)]]; +STEEL_CONST int rader_2_steps_ [[function_constant(21)]]; + +// See "radix.h" for radix codelets +typedef void (*RadixFunc)(thread float2*, thread float2*); + +// Perform a single radix n butterfly with appropriate twiddles +template +METAL_FUNC void radix_butterfly( + int i, + int p, + thread float2* x, + thread short* indices, + thread float2* y) { + // i: the index in the overall DFT that we're processing. + // p: the size of the DFTs we're merging at this step. + // m: how many threads are working on this DFT. 
+ int k, j; + + // Use faster bitwise operations when working with powers of two + constexpr bool radix_p_2 = (radix & (radix - 1)) == 0; + if (radix_p_2 && is_power_of_2_) { + constexpr short power = __builtin_ctz(radix); + k = i & (p - 1); + j = ((i - k) << power) + k; + } else { + k = i % p; + j = (i / p) * radix * p + k; + } + + // Apply twiddles + if (p > 1) { + float2 twiddle_1 = get_twiddle(k, radix * p); + float2 twiddle = twiddle_1; + x[1] = complex_mul(x[1], twiddle); + + STEEL_PRAGMA_UNROLL + for (int t = 2; t < radix; t++) { + twiddle = complex_mul(twiddle, twiddle_1); + x[t] = complex_mul(x[t], twiddle); + } + } + + radix_func(x, y); + + STEEL_PRAGMA_UNROLL + for (int t = 0; t < radix; t++) { + indices[t] = j + t * p; + } +} + +// Perform all the radix steps required for a +// particular radix size n. +template +METAL_FUNC void radix_n_steps( + int i, + thread int* p, + int m, + int n, + int num_steps, + thread float2* inputs, + thread short* indices, + thread float2* values, + threadgroup float2* buf) { + int m_r = n / radix; + // When combining different sized radices, we have to do + // multiple butterflies in a single thread. + // E.g. n = 28 = 4 * 7 + // 4 threads, 7 elems_per_thread + // All threads do 1 radix7 butterfly. + // 3 threads do 2 radix4 butterflies. + // 1 thread does 1 radix4 butterfly. 
+ int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix; + + int index = 0; + int r_index = 0; + for (int s = 0; s < num_steps; s++) { + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + inputs[r] = buf[index + r * m_r]; + } + radix_butterfly( + index, *p, inputs, indices + t * radix, values + t * radix); + } + } + + // Wait until all threads have read their inputs into thread local mem + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int t = 0; t < max_radices_per_thread; t++) { + index = i + t * m; + if (index < m_r) { + for (int r = 0; r < radix; r++) { + r_index = t * radix + r; + buf[indices[r_index]] = values[r_index]; + } + } + } + + // Wait until all threads have written back to threadgroup mem + threadgroup_barrier(mem_flags::mem_threadgroup); + *p *= radix; + } +} + +#define RADIX_STEP(radix, radix_func, num_steps) \ + radix_n_steps( \ + fft_idx, p, m, n, num_steps, inputs, indices, values, buf); + +template +METAL_FUNC void +perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) { + float2 inputs[MAX_RADIX]; + short indices[MAX_OUTPUT_SIZE]; + float2 values[MAX_OUTPUT_SIZE]; + + RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_); + RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_); + RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_); + RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_); + RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_); + RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_); + RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_); + RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_); + RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_); +} + +// Each FFT is computed entirely in shared GPU memory. +// +// N is decomposed into radix-n DFTs: +// e.g. 
128 = 2 * 4 * 4 * 4 +template +[[kernel]] void fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void rader_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* raders_b_q [[buffer(2)]], + const device short* raders_g_q [[buffer(3)]], + const device short* raders_g_minus_q [[buffer(4)]], + constant const int& n, + constant const int& batch_size, + constant const int& rader_n, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Use Rader's algorithm to compute fast FFTs + // when a prime factor `p` of `n` is greater than 13 but + // has `p - 1` Stockham decomposable into to prime factors <= 13. + // + // E.g. n = 102 + // = 2 * 3 * 17 + // . = 2 * 3 * RADER(16) + // . = 2 * 3 * RADER(4 * 4) + // + // In numpy: + // x_perm = x[g_q] + // y = np.fft.fft(x_perm) * b_q + // z = np.fft.ifft(y) + x[0] + // out = z[g_minus_q] + // out[0] = x[1:].sum() + // + // Where the g_q and g_minus_q are permutations formed + // by the group under multiplicative modulo N using the + // primitive root of N and b_q is a constant. 
+ // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm + // + // Rader's uses fewer operations than Bluestein's and so + // is more accurate. It's also faster in most cases. + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // The number of the threads we're using for each DFT + int m = grid.z; + + int fft_idx = elem.z; + int tg_idx = elem.y * n; + threadgroup float2* buf = &shared_in[tg_idx]; + + // rader_m = n / rader_n; + int rader_m = rader_m_; + + // We have to load two x_0s for each thread since sometimes + // elems_per_thread_ crosses a boundary. + // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4 + // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 + // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + short x_0_index = + metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); + float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; + + // Do the Rader permutation in shared memory + float2 temp[MAX_RADIX]; + int max_index = n - rader_m - 1; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q = raders_g_q[index / rader_m]; + temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[index + rader_m] = temp[e]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader FFT on x[rader_m:] + int p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + // x_1 + ... 
+ x_n is computed for us in the first FFT step so + // we save it in the first rader_m indices of the array for later. + int x_sum_index = metal::min(fft_idx, rader_m - 1); + buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; + + float2 inv = {1.0f, -1.0f}; + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short interleaved_index = + index / rader_m + (index % rader_m) * (rader_n - 1); + temp[e] = complex_mul( + buf[rader_m + interleaved_index], + raders_b_q[interleaved_index % (rader_n - 1)]); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + buf[rader_m + index] = temp[e] * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Rader IFFT on x[rader_m:] + p = 1; + perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); + + float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); + short diff_index = index / (rader_n - 1) - x_0_index; + temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; + } + + // Use the sum of elements that was computed in the first FFT + float2 x_sum = buf[x_0_index] + x_0[0]; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int e = 0; e < elems_per_thread_; e++) { + short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); + short g_q_index = index % (rader_n - 1); + short g_q = raders_g_minus_q[g_q_index]; + short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); + buf[out_index] = temp[e]; + } + + buf[x_0_index * rader_n] = x_sum; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + p = rader_n; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write(); +} + +template +[[kernel]] void bluestein_fft( + const 
device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + const device float2* w_q [[buffer(2)]], + const device float2* w_k [[buffer(3)]], + constant const int& length, + constant const int& n, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Computes arbitrary length FFTs with Bluestein's algorithm + // + // In numpy: + // bluestein_n = next_power_of_2(2*n - 1) + // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) + // + // Where w_k and w_q are precomputed on CPU in high precision as: + // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) + // w_q = np.fft.fft(1/w_k[-n:]) + threadgroup float2 shared_in[tg_mem_size]; + + thread ReadWriter read_writer = ReadWriter( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_padded(length, w_k); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + int fft_idx = elem.z; // Thread index in DFT + int m = grid.z; // Threads per DFT + int tg_idx = elem.y * n; // Index of this DFT in threadgroup + threadgroup float2* buf = &shared_in[tg_idx]; + + // fft + perform_fft(fft_idx, &p, m, n, buf); + + float2 inv = float2(1.0f, -1.0f); + for (int t = 0; t < elems_per_thread_; t++) { + int index = fft_idx + t * m; + buf[index] = complex_mul(buf[index], w_q[index]) * inv; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // ifft + p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_padded(length, w_k); +} + +template < + int tg_mem_size, + typename in_T, + typename out_T, + int step, + bool real = false> +[[kernel]] void four_step_fft( + const device in_T* in [[buffer(0)]], + device out_T* out [[buffer(1)]], + constant const int& n1, + constant const int& n2, + constant const int& batch_size, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // 
Fast four step FFT implementation for powers of 2. + int overall_n = n1 * n2; + int n = step == 0 ? n1 : n2; + int stride = step == 0 ? n2 : n1; + + // The number of the threads we're using for each DFT + int m = grid.z; + int fft_idx = elem.z; + + threadgroup float2 shared_in[tg_mem_size]; + threadgroup float2* buf = &shared_in[elem.y * n]; + + using read_writer_t = ReadWriter; + read_writer_t read_writer = read_writer_t( + in, + &shared_in[0], + out, + n, + batch_size, + elems_per_thread_, + elem, + grid, + inv_); + + if (read_writer.out_of_bounds()) { + return; + }; + read_writer.load_strided(stride, overall_n); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + int p = 1; + perform_fft(fft_idx, &p, m, n, buf); + + read_writer.write_strided(stride, overall_n); +} diff --git a/dist/include/mlx/backend/metal/kernels/fft/radix.h b/dist/include/mlx/backend/metal/kernels/fft/radix.h new file mode 100644 index 0000000..bd61eef --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fft/radix.h @@ -0,0 +1,328 @@ +// Copyright © 2024 Apple Inc. + +/* Radix kernels + +We provide optimized, single threaded Radix codelets +for n=2,3,4,5,6,7,8,10,11,12,13. + +For n=2,3,4,5,6 we hand write the codelets. +For n=8,10,12 we combine smaller codelets. +For n=7,11,13 we use Rader's algorithm which decomposes +them into (n-1)=6,10,12 codelets. 
*/ + +#pragma once + +#include +#include +#include + +METAL_FUNC float2 complex_mul(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); +} + +// Complex mul followed by conjugate +METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { + return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x); +} + +// Compute an FFT twiddle factor +METAL_FUNC float2 get_twiddle(int k, int p) { + float theta = -2.0f * k * M_PI_F / p; + + float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)}; + return twiddle; +} + +METAL_FUNC void radix2(thread float2* x, thread float2* y) { + y[0] = x[0] + x[1]; + y[1] = x[0] - x[1]; +} + +METAL_FUNC void radix3(thread float2* x, thread float2* y) { + float pi_2_3 = -0.8660254037844387; + + float2 a_1 = x[1] + x[2]; + float2 a_2 = x[1] - x[2]; + + y[0] = x[0] + a_1; + float2 b_1 = x[0] - 0.5 * a_1; + float2 b_2 = pi_2_3 * a_2; + + float2 b_2_j = {-b_2.y, b_2.x}; + y[1] = b_1 + b_2_j; + y[2] = b_1 - b_2_j; +} + +METAL_FUNC void radix4(thread float2* x, thread float2* y) { + float2 z_0 = x[0] + x[2]; + float2 z_1 = x[0] - x[2]; + float2 z_2 = x[1] + x[3]; + float2 z_3 = x[1] - x[3]; + float2 z_3_i = {z_3.y, -z_3.x}; + + y[0] = z_0 + z_2; + y[1] = z_1 + z_3_i; + y[2] = z_0 - z_2; + y[3] = z_1 - z_3_i; +} + +METAL_FUNC void radix5(thread float2* x, thread float2* y) { + float2 root_5_4 = 0.5590169943749475; + float2 sin_2pi_5 = 0.9510565162951535; + float2 sin_1pi_5 = 0.5877852522924731; + + float2 a_1 = x[1] + x[4]; + float2 a_2 = x[2] + x[3]; + float2 a_3 = x[1] - x[4]; + float2 a_4 = x[2] - x[3]; + + float2 a_5 = a_1 + a_2; + float2 a_6 = root_5_4 * (a_1 - a_2); + float2 a_7 = x[0] - a_5 / 4; + float2 a_8 = a_7 + a_6; + float2 a_9 = a_7 - a_6; + float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4; + float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4; + float2 a_10_j = {a_10.y, -a_10.x}; + float2 a_11_j = {a_11.y, -a_11.x}; + + y[0] = x[0] + a_5; + y[1] = a_8 + a_10_j; + y[2] = a_9 + a_11_j; + y[3] = a_9 
- a_11_j; + y[4] = a_8 - a_10_j; +} + +METAL_FUNC void radix6(thread float2* x, thread float2* y) { + float sin_pi_3 = 0.8660254037844387; + float2 a_1 = x[2] + x[4]; + float2 a_2 = x[0] - a_1 / 2; + float2 a_3 = sin_pi_3 * (x[2] - x[4]); + float2 a_4 = x[5] + x[1]; + float2 a_5 = x[3] - a_4 / 2; + float2 a_6 = sin_pi_3 * (x[5] - x[1]); + float2 a_7 = x[0] + a_1; + + float2 a_3_i = {a_3.y, -a_3.x}; + float2 a_6_i = {a_6.y, -a_6.x}; + float2 a_8 = a_2 + a_3_i; + float2 a_9 = a_2 - a_3_i; + float2 a_10 = x[3] + a_4; + float2 a_11 = a_5 + a_6_i; + float2 a_12 = a_5 - a_6_i; + + y[0] = a_7 + a_10; + y[1] = a_8 - a_11; + y[2] = a_9 + a_12; + y[3] = a_7 - a_10; + y[4] = a_8 + a_11; + y[5] = a_9 - a_12; +} + +METAL_FUNC void radix7(thread float2* x, thread float2* y) { + // Rader's algorithm + float2 inv = {1 / 6.0, -1 / 6.0}; + + // fft + float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]}; + radix6(in1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879)); + y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629)); + y[4] = complex_mul_conj(y[4], float2(0, -2.64575131)); + y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629)); + y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879)); + + // ifft + radix6(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[5] = x[2] * inv + x[0]; + y[4] = x[3] * inv + x[0]; + y[6] = x[4] * inv + x[0]; + y[2] = x[5] * inv + x[0]; + y[3] = x[6] * inv + x[0]; +} + +METAL_FUNC void radix8(thread float2* x, thread float2* y) { + float cos_pi_4 = 0.7071067811865476; + float2 w_0 = {cos_pi_4, -cos_pi_4}; + float2 w_1 = {-cos_pi_4, -cos_pi_4}; + float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]}; + radix4(temp, x); + radix4(temp + 4, x + 4); + + y[0] = x[0] + x[4]; + y[4] = x[0] - x[4]; + float2 x_5 = complex_mul(x[5], w_0); + y[1] = x[1] + x_5; + y[5] = x[1] - x_5; + float2 x_6 = {x[6].y, -x[6].x}; + y[2] = 
x[2] + x_6; + y[6] = x[2] - x_6; + float2 x_7 = complex_mul(x[7], w_1); + y[3] = x[3] + x_7; + y[7] = x[3] - x_7; +} + +template +METAL_FUNC void radix10(thread float2* x, thread float2* y) { + float2 w[4]; + w[0] = {0.8090169943749475, -0.5877852522924731}; + w[1] = {0.30901699437494745, -0.9510565162951535}; + w[2] = {-w[1].x, w[1].y}; + w[3] = {-w[0].x, w[0].y}; + + if (raders_perm) { + float2 temp[10] = { + x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } else { + float2 temp[10] = { + x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]}; + radix5(temp, x); + radix5(temp + 5, x + 5); + } + + y[0] = x[0] + x[5]; + y[5] = x[0] - x[5]; + for (int t = 1; t < 5; t++) { + float2 a = complex_mul(x[t + 5], w[t - 1]); + y[t] = x[t] + a; + y[t + 5] = x[t] - a; + } +} + +METAL_FUNC void radix11(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 10.0, -1 / 10.0}; + + // fft + radix10(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649)); + y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656)); + y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479)); + y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150)); + y[6] = complex_mul_conj(y[6], float2(0, -3.31662479)); + y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150)); + y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479)); + y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656)); + y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649)); + + // ifft + radix10(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[6] = x[2] * inv + x[0]; + y[3] = x[3] * inv + x[0]; + y[7] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[10] = x[6] * inv + x[0]; + y[5] = x[7] * inv + x[0]; + y[8] = x[8] * inv + x[0]; + y[4] = x[9] * inv + x[0]; + y[2] = x[10] * inv + x[0]; +} + 
+template +METAL_FUNC void radix12(thread float2* x, thread float2* y) { + float2 w[6]; + float sin_pi_3 = 0.8660254037844387; + w[0] = {sin_pi_3, -0.5}; + w[1] = {0.5, -sin_pi_3}; + w[2] = {0, -1}; + w[3] = {-0.5, -sin_pi_3}; + w[4] = {-sin_pi_3, -0.5}; + + if (raders_perm) { + float2 temp[12] = { + x[0], + x[3], + x[2], + x[11], + x[8], + x[9], + x[1], + x[7], + x[5], + x[10], + x[4], + x[6]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } else { + float2 temp[12] = { + x[0], + x[2], + x[4], + x[6], + x[8], + x[10], + x[1], + x[3], + x[5], + x[7], + x[9], + x[11]}; + radix6(temp, x); + radix6(temp + 6, x + 6); + } + + y[0] = x[0] + x[6]; + y[6] = x[0] - x[6]; + for (int t = 1; t < 6; t++) { + float2 a = complex_mul(x[t + 6], w[t - 1]); + y[t] = x[t] + a; + y[t + 6] = x[t] - a; + } +} + +METAL_FUNC void radix13(thread float2* x, thread float2* y) { + // Raders Algorithm + float2 inv = {1 / 12.0, -1 / 12.0}; + + // fft + radix12(x + 1, y + 1); + + y[0] = y[1] + x[0]; + + // b_q + y[1] = complex_mul_conj(y[1], float2(-1, 0)); + y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669)); + y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823)); + y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161)); + y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690)); + y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267)); + y[7] = complex_mul_conj(y[7], float2(3.60555128, 0)); + y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267)); + y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690)); + y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161)); + y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823)); + y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669)); + + // ifft + radix12(y + 1, x + 1); + + y[1] = x[1] * inv + x[0]; + y[7] = x[2] * inv + x[0]; + y[10] = x[3] * inv + x[0]; + y[5] = x[4] * inv + x[0]; + y[9] = x[5] * inv + x[0]; + y[11] = x[6] * inv + x[0]; + y[12] = x[7] * inv + 
x[0]; + y[6] = x[8] * inv + x[0]; + y[3] = x[9] * inv + x[0]; + y[8] = x[10] * inv + x[0]; + y[4] = x[11] * inv + x[0]; + y[2] = x[12] * inv + x[0]; +} \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/fft/readwrite.h b/dist/include/mlx/backend/metal/kernels/fft/readwrite.h new file mode 100644 index 0000000..0dc6299 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fft/readwrite.h @@ -0,0 +1,624 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/backend/metal/kernels/fft/radix.h" + +/* FFT helpers for reading and writing from/to device memory. + +For many sizes, GPU FFTs are memory bandwidth bound so +read/write performance is important. + +Where possible, we read 128 bits sequentially in each thread, +coalesced with accesses from adjacent threads for optimal performance. + +We implement specialized reading/writing for: + - FFT + - RFFT + - IRFFT + +Each with support for: + - Contiguous reads + - Padded reads + - Strided reads +*/ + +#define MAX_RADIX 13 + +using namespace metal; + +template < + typename in_T, + typename out_T, + int step = 0, + bool four_step_real = false> +struct ReadWriter { + const device in_T* in; + threadgroup float2* buf; + device out_T* out; + int n; + int batch_size; + int elems_per_thread; + uint3 elem; + uint3 grid; + int threads_per_tg; + bool inv; + + // Used for strided access + int strided_device_idx = 0; + int strided_shared_idx = 0; + + METAL_FUNC ReadWriter( + const device in_T* in_, + threadgroup float2* buf_, + device out_T* out_, + const short n_, + const int batch_size_, + const short elems_per_thread_, + const uint3 elem_, + const uint3 grid_, + const bool inv_) + : in(in_), + buf(buf_), + out(out_), + n(n_), + batch_size(batch_size_), + elems_per_thread(elems_per_thread_), + elem(elem_), + grid(grid_), + inv(inv_) { + // Account for padding on last threadgroup + threads_per_tg = elem.x == grid.x - 1 + ? 
(batch_size - (grid.x - 1) * grid.y) * grid.z + : grid.y * grid.z; + } + + // ifft(x) = 1/n * conj(fft(conj(x))) + METAL_FUNC float2 post_in(float2 elem) const { + return inv ? float2(elem.x, -elem.y) : elem; + } + + // Handle float case for generic RFFT alg + METAL_FUNC float2 post_in(float elem) const { + return float2(elem, 0); + } + + METAL_FUNC float2 pre_out(float2 elem) const { + return inv ? float2(elem.x / n, -elem.y / n) : elem; + } + + METAL_FUNC float2 pre_out(float2 elem, int length) const { + return inv ? float2(elem.x / length, -elem.y / length) : elem; + } + + METAL_FUNC bool out_of_bounds() const { + // Account for possible extra threadgroups + int grid_index = elem.x * grid.y + elem.y; + return grid_index >= batch_size; + } + + METAL_FUNC void load() const { + size_t batch_idx = size_t(elem.x * grid.y) * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + // 2 complex64s = 128 bits + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + buf[index] = post_in(in[batch_idx + index]); + buf[index + 1] = post_in(in[batch_idx + index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + buf[index] = post_in(in[batch_idx + index]); + } + } + + METAL_FUNC void write() const { + size_t batch_idx = size_t(elem.x * grid.y) * n; + short tg_idx = elem.y * grid.z + elem.z; + short max_index = grid.y * n - 2; + + constexpr int read_width = 2; + for (short e = 0; e < (elems_per_thread / read_width); e++) { + short index = read_width * tg_idx + read_width * threads_per_tg * e; + index = metal::min(index, max_index); + // vectorized reads + out[batch_idx + index] = pre_out(buf[index]); + out[batch_idx + index + 1] = 
pre_out(buf[index + 1]); + } + max_index += 1; + if (elems_per_thread % 2 != 0) { + short index = tg_idx + + read_width * threads_per_tg * (elems_per_thread / read_width); + index = metal::min(index, max_index); + out[batch_idx + index] = pre_out(buf[index]); + } + } + + // Padded IO for Bluestein's algorithm + METAL_FUNC void load_padded(int length, const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = post_in(in[batch_idx + index]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0.0; + } + } + } + + METAL_FUNC void write_padded(int length, const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; + int fft_idx = elem.z; + int m = grid.z; + float2 inv_factor = {1.0f / n, -1.0f / n}; + + threadgroup float2* seq_buf = buf + elem.y * n; + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = seq_buf[index + length - 1] * inv_factor; + out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length); + } + } + } + + // Strided IO for four step FFT + METAL_FUNC void compute_strided_indices(int stride, int overall_n) { + // Use the batch threadgroup dimension to coalesce memory accesses: + // e.g. 
stride = 12 + // device | shared mem + // 0 1 2 3 | 0 12 - - + // - - - - | 1 13 - - + // - - - - | 2 14 - - + // 12 13 14 15 | 3 15 - - + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread; + } + + // Four Step FFT First Step + METAL_FUNC void load_strided(int stride, int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + buf[strided_shared_idx + e] = + post_in(in[strided_device_idx + e * stride]); + } + } + + METAL_FUNC void write_strided(int stride, int overall_n) { + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + int combined_idx = (strided_device_idx + e * stride) % overall_n; + int ij = (combined_idx / stride) * (combined_idx % stride); + // Apply four step twiddles at end of first step + float2 twiddle = get_twiddle(ij, overall_n); + out[strided_device_idx + e * stride] = complex_mul(output, twiddle); + } + } +}; + +// Four Step FFT Second Step +template <> +METAL_FUNC void ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + for (int e = 0; e < elems_per_thread; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = pre_out(output, overall_n); + } +} + +// For RFFT, we interleave batches of two 
real sequences into one complex one: +// +// z_k = x_k + j.y_k +// X_k = (Z_k + Z_(N-k)*) / 2 +// Y_k = -j * ((Z_k - Z_(N-k)*) / 2) +// +// This roughly doubles the throughput over the regular FFT. +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for RFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + seq_buf[index].x = in[batch_idx + index]; + seq_buf[index].y = in[batch_idx + index + next_in]; + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + short n_over_2 = (n / 2) + 1; + + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n_over_2; + + float2 conj = {1, -1}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, n_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + out[batch_idx + index] = {seq_buf[index].x, 0}; + out[batch_idx + index + next_out] = {seq_buf[index].y, 0}; + } else { + float2 x_k = seq_buf[index]; + float2 x_n_minus_k = seq_buf[n - index] * conj; + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + if (index < length) { + float2 elem = + float2(in[batch_idx + index], in[batch_idx + index + next_in]); + seq_buf[index] = complex_mul(elem, w_k[index]); + } else { + seq_buf[index] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + int length_over_2 = (length / 2) + 1; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 
0 + : length_over_2; + + float2 conj = {1, -1}; + float2 inv_factor = {1.0f / n, -1.0f / n}; + float2 minus_j = {0, -1}; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread / 2 + 1; e++) { + int index = metal::min(fft_idx + e * m, length_over_2 - 1); + // x_0 = z_0.real + // y_0 = z_0.imag + if (index == 0) { + float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor); + out[batch_idx + index] = float2(elem.x, 0); + out[batch_idx + index + next_out] = float2(elem.y, 0); + } else { + float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor); + float2 x_n_minus_k = complex_mul( + w_k[length - index], seq_buf[length - index] * inv_factor); + x_n_minus_k *= conj; + // w_k should happen before this extraction + out[batch_idx + index] = (x_k + x_n_minus_k) / 2; + out[batch_idx + index + next_out] = + complex_mul(((x_k - x_n_minus_k) / 2), minus_j); + } + } +} + +// For IRFFT, we do the opposite +// +// Z_k = X_k + j.Y_k +// x_k = Re(Z_k) +// Y_k = Imag(Z_k) +template <> +METAL_FUNC bool ReadWriter::out_of_bounds() const { + int grid_index = elem.x * grid.y + elem.y; + // We pack two sequences into one for IRFFTs + return grid_index * 2 >= batch_size; +} + +template <> +METAL_FUNC void ReadWriter::load() const { + short n_over_2 = (n / 2) + 1; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : n_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + // NumPy forces first input to be real + bool first_val = index == 0; + // NumPy forces last input on even irffts to be real + bool last_val = n % 2 == 0 && index == n_over_2 - 1; + if (first_val || last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + seq_buf[index] = x + complex_mul(y, plus_j); + seq_buf[index].y = -seq_buf[index].y; + if (index > 0 && !last_val) { + seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[n - index].y = -seq_buf[n - index].y; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write() const { + int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; + + short m = grid.z; + short fft_idx = elem.z; + + for (int e = 0; e < elems_per_thread; e++) { + int index = metal::min(fft_idx + e * m, n - 1); + out[batch_idx + index] = seq_buf[index].x / n; + out[batch_idx + index + next_out] = seq_buf[index].y / -n; + } +} + +template <> +METAL_FUNC void ReadWriter::load_padded( + int length, + const device float2* w_k) const { + int n_over_2 = (n / 2) + 1; + int length_over_2 = (length / 2) + 1; + + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; + threadgroup float2* seq_buf = buf + elem.y * n; + + // No out of bounds accesses on odd batch sizes + int grid_index = elem.x * grid.y + elem.y; + short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 + ? 
0 + : length_over_2; + + short m = grid.z; + short fft_idx = elem.z; + + float2 conj = {1, -1}; + float2 plus_j = {0, 1}; + + for (int t = 0; t < elems_per_thread / 2 + 1; t++) { + int index = metal::min(fft_idx + t * m, n_over_2 - 1); + float2 x = in[batch_idx + index]; + float2 y = in[batch_idx + index + next_in]; + if (index < length_over_2) { + bool last_val = length % 2 == 0 && index == length_over_2 - 1; + if (last_val) { + x = float2(x.x, 0); + y = float2(y.x, 0); + } + float2 elem1 = x + complex_mul(y, plus_j); + seq_buf[index] = complex_mul(elem1 * conj, w_k[index]); + if (index > 0 && !last_val) { + float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j); + seq_buf[length - index] = + complex_mul(elem2 * conj, w_k[length - index]); + } + } else { + short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2); + seq_buf[pad_index] = 0; + seq_buf[pad_index + 1] = 0; + } + } +} + +template <> +METAL_FUNC void ReadWriter::write_padded( + int length, + const device float2* w_k) const { + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; + threadgroup float2* seq_buf = buf + elem.y * n + length - 1; + + int grid_index = elem.x * grid.y + elem.y; + short next_out = + batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 
0 : length; + + short m = grid.z; + short fft_idx = elem.z; + + float2 inv_factor = {1.0f / n, -1.0f / n}; + for (int e = 0; e < elems_per_thread; e++) { + int index = fft_idx + e * m; + if (index < length) { + float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]); + out[batch_idx + index] = output.x / length; + out[batch_idx + index + next_out] = output.y / -length; + } + } +} + +// Four Step RFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + // Don't invert between steps + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + int coalesce_width = grid.y; + int tg_idx = elem.y * grid.z + elem.z; + int outer_batch_size = stride / coalesce_width; + + int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + + overall_n_over_2 * (elem.x / outer_batch_size); + strided_device_idx = strided_batch_idx + + tg_idx / coalesce_width * elems_per_thread / 2 * stride + + tg_idx % coalesce_width; + strided_shared_idx = (tg_idx % coalesce_width) * n + + tg_idx / coalesce_width * elems_per_thread / 2; + for (int e = 0; e < elems_per_thread / 2; e++) { + float2 output = buf[strided_shared_idx + e]; + out[strided_device_idx + e * stride] = output; + } + + // Add on n/2 + 1 element + if (tg_idx == 0 && elem.x % outer_batch_size == 0) { + out[strided_batch_idx + overall_n / 2] = buf[n / 2]; + } +} + +// Four Step IRFFT +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + int overall_n_over_2 = overall_n / 2 + 1; + auto conj = float2(1, -1); + + compute_strided_indices(stride, overall_n); + // Translate indices in terms of N - k + for (int e = 0; e < elems_per_thread; e++) { + int device_idx = strided_device_idx + e * stride; + int overall_batch = 
device_idx / overall_n; + int overall_index = device_idx % overall_n; + if (overall_index < overall_n_over_2) { + device_idx -= overall_batch * (overall_n - overall_n_over_2); + buf[strided_shared_idx + e] = in[device_idx] * conj; + } else { + int conj_idx = overall_n - overall_index; + device_idx = overall_batch * overall_n_over_2 + conj_idx; + buf[strided_shared_idx + e] = in[device_idx]; + } + } +} + +template <> +METAL_FUNC void +ReadWriter::load_strided( + int stride, + int overall_n) { + // Silence compiler warnings + (void)stride; + (void)overall_n; + bool default_inv = inv; + inv = false; + load(); + inv = default_inv; +} + +template <> +METAL_FUNC void +ReadWriter::write_strided( + int stride, + int overall_n) { + compute_strided_indices(stride, overall_n); + + for (int e = 0; e < elems_per_thread; e++) { + out[strided_device_idx + e * stride] = + pre_out(buf[strided_shared_idx + e], overall_n).x; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/fp4.h b/dist/include/mlx/backend/metal/kernels/fp4.h new file mode 100644 index 0000000..e701adc --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fp4.h @@ -0,0 +1,59 @@ +#pragma once + +constexpr constant static float FP4_LUT[16] = { + +0.0f, + +0.5f, + +1.0f, + +1.5f, + +2.0f, + +3.0f, + +4.0f, + +6.0f, + -0.0f, + -0.5f, + -1.0f, + -1.5f, + -2.0f, + -3.0f, + -4.0f, + -6.0f}; + +struct fp4_e2m1 { + fp4_e2m1(float x) { + if (metal::isnan(x)) { + bits = 0x7; + return; + } + + const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0; + x = metal::abs(x); + + if (x > 5.0f) { + bits = 0x7; + } else if (x >= 3.5f) { + bits = 0x6; + } else if (x > 2.5f) { + bits = 0x5; + } else if (x >= 1.75f) { + bits = 0x4; + } else if (x > 1.25f) { + bits = 0x3; + } else if (x >= 0.75f) { + bits = 0x2; + } else if (x > 0.25f) { + bits = 0x1; + } else { + bits = 0x0; + } + bits |= sign_bit; + } + + operator float() { + half converted = as_type(ushort((bits & 7) << 9)); + converted *= 16384.0; + converted = bits & 8 ? 
-converted : converted; + return converted; + } + + uint8_t bits; +}; diff --git a/dist/include/mlx/backend/metal/kernels/fp8.h b/dist/include/mlx/backend/metal/kernels/fp8.h new file mode 100644 index 0000000..34816b4 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fp8.h @@ -0,0 +1,82 @@ +#pragma once + +struct fp8_e4m3 { + template + fp8_e4m3(T f) { + // From PyTorch + // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 + uint32_t fp8_max = 543 << 21; + uint32_t denorm_mask = 141 << 23; + uint32_t f_bits = as_type(static_cast(f)); + uint32_t sign = f_bits & 0x80000000; + f_bits ^= sign; + if (f_bits >= fp8_max) { + // Default behavior saturates to min/max + bits = 0x7E; + } else { + if (f_bits < (121 << 23)) { + f_bits = as_type( + as_type(f_bits) + as_type(denorm_mask)); + bits = static_cast(f_bits - denorm_mask); + } else { + // resulting mantissa is odd + uint8_t mant_odd = (f_bits >> 20) & 1; + f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; + f_bits += mant_odd; + bits = static_cast(f_bits >> 20); + } + } + bits |= static_cast(sign >> 24); + } + + operator float() { + // From PyTorch: + // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L46 + uint32_t w = static_cast(bits) << 24; + uint32_t sign = w & 0x80000000; + uint32_t nonsign = w & 0x7FFFFFFF; + + uint32_t renorm_shift = metal::clz(nonsign); + renorm_shift = renorm_shift > 4 ? 
renorm_shift - 4 : 0; + + int32_t inf_nan_mask = + (static_cast(nonsign + 0x01000000) >> 8) & 0x7F800000; + int32_t zero_mask = static_cast(nonsign - 1) >> 31; + uint32_t result = sign | + ((((nonsign << renorm_shift >> 4) + ((0x78 - renorm_shift) << 23)) | + inf_nan_mask) & + ~zero_mask); + return as_type(result); + } + + uint8_t bits; +}; + +struct fp8_e8m0 { + fp8_e8m0(float x) { + if (!metal::isfinite(x)) { + bits = 0xFF; + return; + } + if (x < 0.0f) { + bits = 0x00; + return; + } + float le = metal::log2(x); + int n = int(metal::round(le)); + + n = n < -127 ? -127 : n; + n = n > 127 ? 127 : n; + bits = static_cast(n + 127); + } + + operator bfloat16_t() { + uint16_t out = (bits == 0 ? 0x40 : (static_cast(bits) << 7)); + return as_type(out); + } + operator float() { + return static_cast(this->operator bfloat16_t()); + } + + uint8_t bits; +}; diff --git a/dist/include/mlx/backend/metal/kernels/fp_quantized.h b/dist/include/mlx/backend/metal/kernels/fp_quantized.h new file mode 100644 index 0000000..cae1bbd --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fp_quantized.h @@ -0,0 +1,1804 @@ +// Copyright © 2025 Apple Inc. 
+ +#include +#include + +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e8m0*)(&s)); +} + +template +struct Quantize { + uint8_t operator()(float x) { + if (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + float operator()(uint8_t x) { + if (bits == 8) { + return float(*(thread fp8_e4m3*)(&x)); + } else { + return float(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +inline void load_vector(const device T* x, thread U* x_thread) { + for (int i = 0; i < values_per_thread; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } +} + +template +inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { + for (int i = 0; i < N; i += 4) { + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1]; + x_thread[i + 2] = x[i + 2]; + x_thread[i + 3] = x[i + 3]; + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } +} + +template +inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { + U accum = 0; + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * 
Dequantize<4>{}(ws[i] >> 12)); + } + return scale * accum; +} + +template +inline U +qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) { + U accum = 0; + + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + + x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + + x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + + x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); + } + return scale * accum; +} + +template +inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * scale * Dequantize<4>{}(w[i]); + result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4); + } +} + +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device uint8_t* scales; + threadgroup T* lut; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device uint8_t* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + lut(lut_) { + if (simd_group_id == 0 && simd_lane_id < 16) { + lut[simd_lane_id] = static_cast(FP4_LUT[simd_lane_id]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T 
scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + } + } else { + scales++; + } + } else { + scales += group_stride; + } + } +}; + +template +METAL_FUNC void fp_qmv_quad_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 8; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd; + + U s = dequantize_scale(sl[0]); + if (row * quads_per_simd + 
out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void fp_qmv_fast_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device 
auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void fp_qmv_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack<32>(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * 
bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + uint8_t s = sl[0]; + result[row] += qdot(wl, x_thread, s); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += qdot(wl, x_thread, s); + } + + ws += 
block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + load_vector_safe(x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device auto* sl = scales + row * in_vec_size_g; + + U s = dequantize_scale(sl[0]); + result[row] += + qdot_safe(wl, x_thread, s, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void fp_qvm_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor<32>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = uint32_t; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 0; + thread U x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * 
out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + + x += block_size; + scales += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = dequantize_scale(*scales); + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void fp_qmm_t_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* 
lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + 
threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void fp_qmm_n_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup T* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * 
SIMD_SIZE, + group_size>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, N, Ws, lut, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, 
Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device uint8_t*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device uint8_t*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const 
constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void fp_qmv_quad( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_quad_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); +} + +template +[[kernel]] void fp_qmv_fast( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& 
in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qmv( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qvm( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const 
constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_qvm_split_k( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& final_block_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; + + fp_qvm_impl( + w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_qmm_t( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_qmm_n( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid 
[[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template +[[kernel]] void fp_gather_qmv_fast( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmv_fast_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_gather_qmv( + const device uint32_t* w, + const device uint8_t* 
scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmv_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template +[[kernel]] void fp_gather_qvm( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + 
adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qvm_impl( + w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void fp_gather_qmm_t( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_t_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + 
const int BN = 32> +[[kernel]] void fp_gather_qmm_n( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void fp_gather_qmm_rhs( + const device T* x, + const device uint32_t* w, + const device uint8_t* scales, + const device uint32_t* indices, + device T* y, + const constant int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor 
= get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + threadgroup T lut[16]; + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? 
y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, 
K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template +[[kernel]] void fp_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device uint8_t* scales [[buffer(2)]], + uint2 tidx [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + size_t index = tidx.x + grid_dim.x * size_t(tidx.y); + + float scale; + float w_thread = w[index]; + if (use_mx_scale) { + scale = simd_max(abs(w_thread)); + } else { + float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); + float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); + scale = tidx.x < 16 ? w_max_l : w_max_r; + } + scale /= bits == 4 ? 
6.0f : 448.0f; + + using ScaleType = metal::conditional_t; + auto s = ScaleType(scale); + uint8_t q_scale = s.bits; + scale = float(s); + + // Write out the scales and biases + size_t gindex = index / group_size; + if (index % group_size == 0) { + scales[gindex] = q_scale; + } + + uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); + if (bits == 4) { + uint8_t sval = simd_shuffle_down(output, 1); + output |= sval << bits; + } + constexpr int pack_factor = bits == 8 ? 1 : 2; + if (index % pack_factor == 0) { + out[index / pack_factor] = output; + } +} + +template +[[kernel]] void fp_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr bool use_mx_scale = group_size == 32; + constexpr int pack_factor = bits == 8 ? 1 : 2; + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * pack_factor; + size_t gindex = oindex / group_size; + + out += oindex; + + using ScaleType = metal::conditional_t; + auto q_scale = ((device ScaleType*)(scales))[gindex]; + auto scale = float(q_scale); + + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/fp_quantized_nax.h b/dist/include/mlx/backend/metal/kernels/fp_quantized_nax.h new file mode 100644 index 0000000..28c4ec8 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/fp_quantized_nax.h @@ -0,0 +1,1059 @@ +// Copyright © 2025 Apple Inc. 
+ +#include +#include + +#include "mlx/backend/metal/kernels/fp4.h" +#include "mlx/backend/metal/kernels/fp8.h" + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return wsize / 4; +} + +template +inline constexpr short get_bytes_per_pack() { + return wsize / 8; +} + +template +static inline T dequantize_scale(uint8_t s) { + return T(*(thread fp8_e8m0*)(&s)); +} + +template +struct Quantize { + uint8_t operator()(float x) { + if constexpr (bits == 8) { + return fp8_e4m3(x).bits; + } else { + return fp4_e2m1(x).bits; + } + } +}; + +template +struct Dequantize { + float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(thread fp8_e4m3*)(&x)); + } else { + return float(*(thread fp4_e2m1*)(&x)); + } + } +}; + +template +inline void dequantize( + const device uint8_t* w, + U scale, + threadgroup U* w_local, + const threadgroup U* lut) { + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = scale * lut[w[i] & 0xf]; + w_local[2 * i + 1] = scale * lut[(w[i] >> 4) & 0xf]; + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size> +struct QuantizedBlockLoader { + static_assert( + BCOLS % group_size == 0, + "The group size should be divisible by the columns"); + + MLX_MTL_CONST short pack_factor = get_pack_factor<8>(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 
1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + static_assert( + (BCOLS_PACKED / n_reads) == n_groups, + "Other configurations are not yet supported"); + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device uint8_t* scales; + threadgroup T* lut; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device uint8_t* scales_, + const int src_ld_, + threadgroup T* dst_, + threadgroup T* lut_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id), + lut(lut_) { + if (simd_group_id == 0 && simd_lane_id < 16) { + lut[simd_lane_id] = static_cast(FP4_LUT[simd_lane_id]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, dst + i * pack_factor, lut); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if 
(reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = dequantize_scale(*scales); + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + dst + i * pack_factor, + lut); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + // if (group_steps > 1) { + // group_step_cnt++; + // if (group_step_cnt == group_steps) { + // group_step_cnt = 0; + // scales++; + // } + // } else { + scales += n_groups; + // } + } else { + scales += n_groups * group_stride; + } + } +}; + +using namespace mlx::steel; + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_t_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup Wtype* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup Wtype* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + // Instantiate Loader + using loader_w_t = QuantizedBlockLoader< + Wtype, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = 
(const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + 
(void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +METAL_FUNC void fp_qmm_n_impl( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]], + threadgroup Wtype* lut) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + (void)M; + + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + // const short num_els = min(BM, M - y_row); + // const short num_outs = min(BN, N - y_col); + loader_w_t 
loader_w(wl, scales, K, Ws, lut, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { 
+ x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device S*& scales, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + } else { + ulong2 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 
2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + threadgroup Wtype lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / 
sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + } + + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_t_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + + threadgroup Wtype Ws[BN * BK_padded]; + threadgroup Wtype lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_t_impl( + w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, 
simd_lid, lut); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_n_nax( + const device uint32_t* w, + const device uint8_t* scales, + const device T* x, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T* y, + const constant int& K, + const constant int& N, + const constant int& M, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + threadgroup T lut[16]; + + adjust_matrix_offsets( + x, + w, + scales, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + tid); + fp_qmm_n_impl( + w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid, lut); +} + +template < + typename T, + int group_size, + const int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose, + typename Wtype = bfloat> +[[kernel]] void fp_gather_qmm_rhs_nax( + const device T* x, + const device uint32_t* w, + const device uint8_t* scales, + const device uint32_t* indices, + device T* y, + const constant 
int& M, + const constant int& N, + const constant int& K, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor<8>(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); + constexpr int BN_padded = (BN + 16 / sizeof(Wtype)); + + threadgroup Wtype lut[16]; + + using loader_w_t = QuantizedBlockLoader< + Wtype, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size>; + + threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? 
y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + transpose ? K : N, + Ws, + lut, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load( + Ws + tn * BK_padded + kk1); + } else { + Btile.template load( + Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} diff --git 
a/dist/include/mlx/backend/metal/kernels/gemv_masked.h b/dist/include/mlx/backend/metal/kernels/gemv_masked.h new file mode 100644 index 0000000..96b0c28 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/gemv_masked.h @@ -0,0 +1,827 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +typedef struct _NoMask nomask_t; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> +struct GEMVKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + static_assert( + SN == 8 || SN == 16 || SN == 32, + "gemv block must have a width of 8, 16, or 32"); + + static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + 
has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM rows + // and the corresponding scalar from the vector + // 2. The thread then multiplies and adds to accumulate its local result for + // the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. Each threadgroup writes its accumulated blockM outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; + + template + static METAL_FUNC void + load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } + + template + static METAL_FUNC void load_safe( + const device T* src, + thread U dst[TN], + const int src_offset = 0, + const int src_size = TN) { + if (src_offset + TN <= src_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = static_cast(src[src_offset + tn]); + } + } else { // Edgecase + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + dst[tn] = src_offset + tn < src_size + ? 
static_cast(src[src_offset + tn]) + : U(0); + } + } + } + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& matrix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + thread AccT result[TM] = {0}; + thread T inter[TN]; + thread AccT v_coeff[TN]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); + const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; + + int bm = (simdM + thrM) * TM; + int bn = (simdN + thrN) * TN; + + // Block position + int out_row = tid.x * blockM + bm; + + // Exit simdgroup if rows out of bound + if (out_row >= out_vec_size) + return; + + // Adjust tail simdgroup to ensure in bound reads + out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 
0 : m_block_idx * out_mask_strides[1]; + + int mat_mask_offset = + !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = T(0.); + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Advance matrix + mat += out_row * matrix_ld; + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockN); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Loop over in_vec in blocks of blockN + for (int i = 0; i < n_iter; ++i) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_unsafe(in_vec, v_coeff, bn); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + int mat_offset = 0; + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_unsafe(mat, inter, mat_offset + bn); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + + mat_offset += matrix_ld; + } + } + + bn += blockN; + mat_mask_offset += mat_mask_step; + vec_mask_offset += 
vec_mask_step; + } + + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + load_safe(in_vec, v_coeff, bn, in_size); + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + v_coeff[tn] *= block_scale; + } + } + + // Per thread work loop + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + // Load for the row + load_safe(&mat[tm * matrix_ld], inter, bn, in_size); + + // Accumulate results + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tm] += inter[tn] * v_coeff[tn]; + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; + if (thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + tgp_results[tm] = result[tm]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgn = 1; sgn < BN; sgn++) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + result[tm] += tgp_results[sgn * (blockM + TM) + tm]; + } + } + } + } + } + + // Write outputs + if (simdN == 0 && thrN == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + out_vec[out_row + tm] = static_cast(result[tm]); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication 
+/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + typename AccT = float> +struct GEMVTKernel { + MLX_MTL_CONST int threadsM = BM * SM; + MLX_MTL_CONST int threadsN = BN * SN; + + MLX_MTL_CONST int blockM = threadsM * TM; + MLX_MTL_CONST int blockN = threadsN * TN; + + static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); + + MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; + MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; + + MLX_MTL_CONST bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + MLX_MTL_CONST bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up + // into blocks of (blockM, blockN) divided among threadgroups + // - Every thread works on a block of (TM, TN) + // - We assume each threadgroup has (threadsN, threadsM, 1) threads + // + // 1. A thread loads TN elements each from mat along TM contiguous rows + // and the corresponding scalar from the vector + // 2. The thread then accumulates its local result for the block + // 3. At the end, each thread has accumulated results over all blocks across + // the rows. These are then summed up across the threadgroup + // 4. 
Each threadgroup writes its accumulated BN * TN outputs + // + // Edge case handling: + // - The threadgroup with the largest tid has blocks that exceed the matrix + // * The blocks that start outside the matrix are never read (thread results + // remain zero) + // * The last thread that partially overlaps with the matrix is shifted + // inwards such that the thread block fits exactly in the matrix + + MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; + MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; + + static METAL_FUNC void run( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + threadgroup AccT* tgp_memory [[threadgroup(0)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + // Appease compiler + (void)lid; + + // Thread local accumulation results + AccT result[TN] = {0}; + T inter[TN]; + AccT v_coeff[TM]; + + const int thrM = SN != 32 ? simd_lid / SN : 0; + const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); + + const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); + const int sgN = BN != 1 ? (simd_gid % BN) : 0; + + const int simdM = SM * sgM; + const int simdN = SN * sgN; + + int cm = (simdM + thrM); + int cn = (simdN + thrN); + + int bm = cm * TM; + int bn = cn * TN; + + int out_col = tid.x * blockN + bn; + + // Prepare mask offsets + const constant int* out_mask_strides = mask_strides; + const constant int* mat_mask_strides = + out_mask_strides + (has_output_mask ? 
2 : 0); + const constant int* vec_mask_strides = + mat_mask_strides + (has_operand_mask ? 2 : 0); + + const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); + + const int out_mask_offset = + !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; + + int mat_mask_offset = + !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; + int vec_mask_offset = 0; + const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; + const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; + + T out_scale{1}; + + // Check output mask + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + // Write zeros and return if mask is 0 + if (!mask_out) { + if (cm == 0 && out_col < out_vec_size) { + if (out_col + TN <= out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + out_vec[out_col + tn] = T(0.); + } + } else { + for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { + out_vec[out_col + tn] = T(0.); + } + } + } + + return; + } + + // Store scalar if multiplicative mask + if (has_mul_output_mask) { + out_scale = T(mask_out); + } + } + + // Prepare for loop + constexpr const uniform loop_stride = make_uniform(blockM); + const uniform in_size = make_uniform(in_vec_size); + const uniform n_iter = in_size / loop_stride; + const uniform last_iter = loop_stride * n_iter; + const uniform leftover = in_size - last_iter; + + // Edgecase handling + if (out_col < out_vec_size) { + out_col = (out_col + TN) <= out_vec_size ? 
out_col : out_vec_size - TN; + + // Per thread accumulation main loop + for (int i = 0; i < n_iter; ++i) { + // Adding a threadgroup_barrier improves performance slightly + // This is possibly it may help exploit cache better + threadgroup_barrier(mem_flags::mem_none); + + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + } + + // Apply scale + if (has_mul_operand_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + v_coeff[tm] *= block_scale; + } + } + + MLX_MTL_PRAGMA_UNROLL + for (int tm = 0; tm < TM; tm++) { + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + + bm += blockM; + mat_mask_offset += mat_mask_step; + vec_mask_offset += vec_mask_step; + } + + if (leftover > 0) { + if (!has_operand_mask || + (bool(mat_mask[mat_mask_offset]) && + bool(vec_mask[vec_mask_offset]))) { + T block_scale{1}; + if (has_mul_operand_mask) { + block_scale = + T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); + } + + for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { + v_coeff[tm] = static_cast(in_vec[bm + tm]); + + if (has_mul_operand_mask) { + v_coeff[tm] *= block_scale; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; + } + + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += v_coeff[tm] * inter[tn]; + } + } + } + } + } + + // Apply out scale + if (has_mul_output_mask) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] *= out_scale; + } + } + + // Simdgroup accumulations + MLX_MTL_PRAGMA_UNROLL + for 
(int tn = 0; tn < TN; tn++) { + MLX_MTL_PRAGMA_UNROLL + for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { + result[tn] += simd_shuffle_down(result[tn], SN * sm); + } + } + + // Threadgroup accumulation results + if (needs_tgp_reduction) { + threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; + if (thrM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + tgp_results[tn] = result[tn]; + } + + threadgroup_barrier(mem_flags::mem_none); + + if (sgM == 0) { + MLX_MTL_PRAGMA_UNROLL + for (int sgm = 1; sgm < BM; sgm++) { + MLX_MTL_PRAGMA_UNROLL + for (int tn = 0; tn < TN; tn++) { + result[tn] += tgp_results[sgm * (blockN + TN) + tn]; + } + } + } + } + } + + // Threadgroup accumulation and writing out results + if (cm == 0 && out_col < out_vec_size) { + MLX_MTL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + out_vec[out_col + j] = static_cast(result[j]); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/// Matrix vector multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* 
vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVKernel; + threadgroup float tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + 
mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} + +/////////////////////////////////////////////////////////////////////////////// +/// Vector matrix multiplication +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + const int BM, /* Threadgroup rows (in simdgroups) */ + const int BN, /* Threadgroup cols (in simdgroups) */ + const int SM, /* Simdgroup rows (in threads) */ + const int SN, /* Simdgroup cols (in threads) */ + const int TM, /* Thread rows (in elements) */ + const int TN, /* Thread cols (in elements) */ + const bool kDoNCBatch> /* Batch ndim > 1 */ +[[kernel, max_total_threads_per_threadgroup(BM* BN * 32)]] void gemv_t_masked( + const device T* mat [[buffer(0)]], + const device T* in_vec [[buffer(1)]], + device T* out_vec [[buffer(3)]], + const constant int& in_vec_size [[buffer(4)]], + const constant int& out_vec_size [[buffer(5)]], + const constant int& marix_ld [[buffer(6)]], + const constant int& batch_ndim [[buffer(9)]], + const constant int* batch_shape [[buffer(10)]], + const constant int64_t* vector_batch_stride [[buffer(11)]], + const constant int64_t* matrix_batch_stride [[buffer(12)]], + const device out_mask_t* out_mask [[buffer(20)]], + const device op_mask_t* mat_mask [[buffer(21)]], + const device op_mask_t* vec_mask [[buffer(22)]], + const constant int* mask_strides [[buffer(23)]], + const constant int64_t* mask_batch_strides [[buffer(24)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using gemv_kernel = + GEMVTKernel; + threadgroup float tgp_memory + [gemv_kernel::tgp_mem_size == 0 ? 
1 : gemv_kernel::tgp_mem_size]; + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + // Update batch offsets + if (kDoNCBatch) { + in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); + mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); + + if (has_output_mask) { + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_mat = mask_batch_strides; + const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); + + mat_mask += batch_offsets.x; + vec_mask += batch_offsets.y; + } + + } else { + in_vec += tid.z * vector_batch_stride[0]; + mat += tid.z * matrix_batch_stride[0]; + + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += batch_ndim; + } + + if (has_operand_mask) { + mat_mask += tid.z * mask_batch_strides[0]; + vec_mask += tid.z * mask_batch_strides[batch_ndim]; + } + } + + out_vec += tid.z * out_vec_size; + + gemv_kernel::run( + mat, + in_vec, + out_vec, + in_vec_size, + out_vec_size, + marix_ld, + out_mask, + mat_mask, + vec_mask, + mask_strides, + gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, + tid, + lid, + simd_gid, + simd_lid); +} diff --git a/dist/include/mlx/backend/metal/kernels/hadamard.h b/dist/include/mlx/backend/metal/kernels/hadamard.h new file mode 100644 index 0000000..9f2311c --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/hadamard.h @@ -0,0 +1,182 @@ +// Copyright © 2024 Apple Inc. 
+#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" + +using namespace metal; + +// Thread local Hadamard transform for 2^R +template +METAL_FUNC void radix_func(thread float* x) { + constexpr short logR = __builtin_ctz(R); + short h = 1; + STEEL_PRAGMA_UNROLL + for (short s = 0; s < logR; s++) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < R / 2; i++) { + short k = i & (h - 1); + short j = ((i - k) << 1) + k; + float a = x[j]; + float b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + h <<= 1; + } +} + +template +[[kernel]] void hadamard_n( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size N = 2^k + // + // Equivalent to: + // from scipy.linalg import hadamard + // y = hadamard(len(x)) @ x + + constexpr short num_threads = N / max_radix; + constexpr short logN = __builtin_ctz(N); + constexpr short logR = __builtin_ctz(max_radix); + constexpr short num_steps = logN / logR; + constexpr short logFinal = logN % logR; + constexpr short final_radix = 1 << (logFinal); + + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; + + threadgroup T buf[N]; + + // Read values from device + if (stride == 1) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + float x[max_radix]; + short h = 1; + + STEEL_PRAGMA_UNROLL + for (short s = 0; s < num_steps; s++) { + short k = i & (h - 1); + short j = ((i - k) << logR) + k; + + STEEL_PRAGMA_UNROLL + for (short r 
= 0; r < max_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < max_radix; r++) { + buf[j + h * r] = T(x[r]); + } + + h <<= logR; + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Do the final radix + // e.g. max_radix = 16 + // N = 1024 = 16 * 16 * 4 + if (final_radix > 1) { + // Each thread does multiple butterflies + STEEL_PRAGMA_UNROLL + for (int t = 0; t < max_radix / final_radix; t++) { + short index = i + t * num_threads; + short k = index & (h - 1); + short j = ((index - k) << logFinal) + k; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + x[r] = buf[j + h * r]; + } + + radix_func(x); + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < final_radix; r++) { + buf[j + h * r] = T(x[r]); + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write values to device + if (stride == 1) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; + } + } +} + +template +[[kernel]] void hadamard_m( + const device T* in [[buffer(0)]], + device T* out [[buffer(1)]], + constant const float& scale, + uint3 elem [[thread_position_in_grid]], + uint3 grid [[threads_per_grid]]) { + // Compute a Hadamard transform of size M + // using a naive O(M^2) codelet. + // + // This kernel is the second stage in the computation + // of a Hadamard transform of size M*N where N = 2^k. 
+ + int index = elem.x * grid.y + elem.y; + short i = index % (N / read_width); + int batch_idx = index / (N / read_width) * M * N; + + float x[read_width][M]; + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + x[r][c] = in[batch_idx + c * N + i * read_width + r]; + } + } + + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + // This function is JIT compiled for M + // using the Hadamard matrix strings in `metal/hadamard.cpp` + hadamard_radix_m(x[r]); + } + + // Write back to device + STEEL_PRAGMA_UNROLL + for (short c = 0; c < M; c++) { + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/gather.h b/dist/include/mlx/backend/metal/kernels/indexing/gather.h new file mode 100644 index 0000000..8b93c01 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/gather.h @@ -0,0 +1,51 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template +METAL_FUNC void gather_impl( + const device T* src [[buffer(0)]], + device T* out [[buffer(1)]], + const constant int* src_shape [[buffer(2)]], + const constant int64_t* src_strides [[buffer(3)]], + const constant size_t& src_ndim [[buffer(4)]], + const constant int* slice_sizes [[buffer(5)]], + const constant int* axes [[buffer(6)]], + const thread Indices& indices, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + LocT src_idx = 0; + for (int i = 0; i < NIDX; ++i) { + LocT idx_loc; + if (IDX_NDIM == 0) { + idx_loc = 0; + } else if (IDX_NDIM == 1) { + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + } else { + idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); + idx_loc += indices.row_contiguous[i] + ? 
index.y + : elem_to_loc( + index.y, + &indices.shapes[indices.ndim * i + 1], + &indices.strides[indices.ndim * i + 1], + indices.ndim - 1); + } + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); + src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); + } + + auto src_offset = + elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); + + LocT out_idx = index.z; + if (IDX_NDIM == 1) { + out_idx += static_cast(grid_dim.z) * index.x; + } else if (IDX_NDIM >= 2) { + out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); + } + out[out_idx] = src[src_offset + src_idx]; +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/gather_axis.h b/dist/include/mlx/backend/metal/kernels/indexing/gather_axis.h new file mode 100644 index 0000000..bf490ad --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/gather_axis.h @@ -0,0 +1,44 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template +[[kernel]] void gather_axis( + const device T* src [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device T* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* src_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& axis_size [[buffer(8)]], + const constant size_t& src_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + LocT elem_idx = index.z * static_cast(grid_dim.x); + LocT out_idx = elem_idx * grid_dim.y + index.x; + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += out_idx; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? 
idx_val + axis_size : idx_val; + } + + LocT src_idx = idx_val * static_cast(src_ax_stride); + if (SrcC) { + src_idx += elem_idx * axis_size + index.x; + } else { + src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); + } + + out_idx += index.y * static_cast(grid_dim.x); + out[out_idx] = src[src_idx]; +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/gather_front.h b/dist/include/mlx/backend/metal/kernels/indexing/gather_front.h new file mode 100644 index 0000000..1389e4c --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/gather_front.h @@ -0,0 +1,24 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template +[[kernel]] void gather_front( + const device T* src, + const device IdxT* indices, + device T* out, + const constant int64_t& stride, + const constant int& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto idx = offset_neg_idx(indices[index.y], size); + LocT src_idx = static_cast(stride) * idx; + LocT out_idx = static_cast(stride) * index.y; + + int s_idx = N * index.x; + for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { + out[out_idx + s_idx] = src[src_idx + s_idx]; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/indexing.h b/dist/include/mlx/backend/metal/kernels/indexing/indexing.h new file mode 100644 index 0000000..2a4b4f9 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/indexing.h @@ -0,0 +1,23 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +template +struct Indices { + const array buffers; + const constant int* shapes; + const constant int64_t* strides; + const constant bool* row_contiguous; + const int ndim; +}; + +template +METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { + if (is_unsigned_v) { + return idx; + } else { + return (idx < 0) ? 
idx + size : idx; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/masked_scatter.h b/dist/include/mlx/backend/metal/kernels/indexing/masked_scatter.h new file mode 100644 index 0000000..1fd19e2 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/masked_scatter.h @@ -0,0 +1,38 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +template +[[kernel]] void masked_assign_impl( + const device bool* mask [[buffer(0)]], + const device uint* scatter_offsets [[buffer(1)]], + const device T* src [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int* src_shapes [[buffer(4)]], + const constant int64_t* src_strides [[buffer(5)]], + const constant int& src_ndim [[buffer(6)]], + const constant int64_t& src_batch_size [[buffer(7)]], + const constant int64_t& mask_batch_size [[buffer(8)]], + uint idx [[thread_position_in_grid]]) { + const bool mask_value = mask[idx]; + if (!mask_value) { + return; + } + + const uint src_index = scatter_offsets[idx]; + if (src_index >= src_batch_size) { + return; + } + + const uint batch_idx = idx / mask_batch_size; + + if (src_contiguous) { + out[idx] = src[batch_idx * src_batch_size + src_index]; + } else { + out[idx] = src[elem_to_loc( + batch_idx * src_batch_size + src_index, + src_shapes, + src_strides, + src_ndim)]; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/scatter.h b/dist/include/mlx/backend/metal/kernels/indexing/scatter.h new file mode 100644 index 0000000..f0217b3 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/scatter.h @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/indexing/indexing.h" + +template < + typename T, + typename IdxT, + typename Op, + int NIDX, + bool UPD_ROW_CONTIG, + int NWORK, + typename LocT> +METAL_FUNC void scatter_impl( + const device T* updates, + device mlx_atomic* out, + const constant int* upd_shape, + const constant int64_t* upd_strides, + const constant size_t& upd_ndim, + const constant size_t& upd_size, + const constant int* out_shape, + const constant int64_t* out_strides, + const constant size_t& out_ndim, + const constant int* axes, + const constant size_t& idx_size, + const thread Indices& indices, + uint2 gid [[thread_position_in_grid]]) { + Op op; + + auto ind_idx = gid.y * NWORK; + LocT out_offset = 0; + if (upd_size > 1) { + out_offset = elem_to_loc( + gid.x, upd_shape + indices.ndim, out_strides, out_ndim); + } + + for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { + LocT out_idx = out_offset; + for (int i = 0; i < NIDX; ++i) { + auto idx_loc = indices.row_contiguous[i] + ? ind_idx + : elem_to_loc( + ind_idx, + &indices.shapes[indices.ndim * i], + &indices.strides[indices.ndim * i], + indices.ndim); + auto ax = axes[i]; + auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); + out_idx += + static_cast(idx_val) * static_cast(out_strides[ax]); + } + auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; + if constexpr (!UPD_ROW_CONTIG) { + upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); + } + op.atomic_update(out, updates[upd_idx], out_idx); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/indexing/scatter_axis.h b/dist/include/mlx/backend/metal/kernels/indexing/scatter_axis.h new file mode 100644 index 0000000..73fd7ab --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/indexing/scatter_axis.h @@ -0,0 +1,52 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +template < + typename T, + typename IdxT, + typename LocT, + typename Op, + bool UpdC, + bool IdxC> +[[kernel]] void scatter_axis( + const device T* upd [[buffer(0)]], + const device IdxT* indices [[buffer(1)]], + device mlx_atomic* out [[buffer(2)]], + const constant int* shape [[buffer(3)]], + const constant int64_t* upd_strides [[buffer(4)]], + const constant int64_t* idx_strides [[buffer(5)]], + const constant size_t& ndim [[buffer(6)]], + const constant int& axis [[buffer(7)]], + const constant int& out_axis_size [[buffer(8)]], + const constant size_t& upd_ax_stride [[buffer(9)]], + const constant size_t& idx_ax_stride [[buffer(10)]], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + Op op; + + LocT elem_idx = index.z * static_cast(grid_dim.x); + + LocT idx_loc = index.y * static_cast(idx_ax_stride); + if (IdxC) { + idx_loc += elem_idx * grid_dim.y + index.x; + } else { + idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); + } + + auto idx_val = indices[idx_loc]; + if (is_signed_v) { + idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val; + } + + LocT upd_idx = index.y * static_cast(upd_ax_stride); + if (UpdC) { + upd_idx += elem_idx * grid_dim.y + index.x; + } else { + upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim); + } + + LocT out_idx = elem_idx * static_cast(out_axis_size) + + idx_val * grid_dim.x + index.x; + op.atomic_update(out, upd[upd_idx], out_idx); +} diff --git a/dist/include/mlx/backend/metal/kernels/logsumexp.h b/dist/include/mlx/backend/metal/kernels/logsumexp.h new file mode 100644 index 0000000..c746050 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/logsumexp.h @@ -0,0 +1,140 @@ +// Copyright © 2025 Apple Inc. 
+ +template +[[kernel]] void logsumexp( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(ld[i] - maxval); + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + out[gid] = isinf(maxval) ? 
T(maxval) : T(log(normalizer) + maxval); + } + } +} + +template +[[kernel]] void logsumexp_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= fast::exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += fast::exp(vals[i] - maxval); + } + } + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= fast::exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= fast::exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + + if (lid == 0) { + out[gid] = isinf(maxval) ? 
T(maxval) : T(log(normalizer) + maxval); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/quantized.h b/dist/include/mlx/backend/metal/kernels/quantized.h new file mode 100644 index 0000000..bf63981 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/quantized.h @@ -0,0 +1,2502 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 
0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i 
+ 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += 
(w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits 
== 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) 
* x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < 
(values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 
0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 * i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + 
short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template +METAL_FUNC void qmv_quad_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid 
[[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; + constexpr int pack_factor = 32 / bits; + constexpr int values_per_thread = D / QUAD_SIZE; + constexpr int packs_per_thread = values_per_thread / pack_factor; + constexpr int scale_step_per_thread = group_size / values_per_thread; + constexpr int results_per_quadgroup = 8; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_quadgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; + + w += out_row * in_vec_size_w + quad_lid * packs_per_thread; + scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; + x += tid.x * in_vec_size + quad_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_quadgroup; row++) { + auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); + const device T* sl = scales + row * in_vec_size_g * quads_per_simd; + const device T* bl = biases + row * in_vec_size_g * quads_per_simd; + + U s = sl[0]; + U b = bl[0]; + if (row * quads_per_simd + out_row < out_vec_size) { + result[row] += qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; row < results_per_quadgroup; row++) { + result[row] = quad_sum(result[row]); + if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { + y[row * quads_per_simd] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_fast_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid 
[[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int packs_per_thread = bits == 2 ? 1 : 2; + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + + ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + for (int k = 0; k < in_vec_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { 
+ y[row] = static_cast(result[row]); + } + } +} + +template +METAL_FUNC void qmv_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const constant int& in_vec_size, + const constant int& out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int num_simdgroups = 2; + constexpr int results_per_simdgroup = 4; + constexpr int packs_per_thread = 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int values_per_thread = pack_factor * packs_per_thread; + constexpr int block_size = values_per_thread * SIMD_SIZE; + constexpr int scale_step_per_thread = group_size / values_per_thread; + + const device uint8_t* ws = (const device uint8_t*)w; + + typedef float U; + + thread U x_thread[values_per_thread]; + thread U result[results_per_simdgroup] = {0}; + + // Adjust positions + const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; + const int in_vec_size_g = in_vec_size / group_size; + const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + + simd_gid * results_per_simdgroup; + const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); + + if (out_row >= out_vec_size) { + return; + } + + // In this case we need to properly guard all our reads because there isn't + // even 1 tile in the matrix + if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { + ws += + out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; + scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, 
x_thread); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; out_row + row < out_vec_size; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, x_thread, s, b, sum); + } + } + + for (int row = 0; out_row + row < out_vec_size; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } + + // In this case the last tile is moved back to redo some output values + else { + ws += used_out_row * in_vec_size_w + + simd_lid * packs_per_thread * bytes_per_pack; + scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; + x += tid.x * in_vec_size + simd_lid * values_per_thread; + y += tid.x * out_vec_size + used_out_row; + + int k = 0; + for (; k < in_vec_size - block_size; k += block_size) { + U sum = load_vector(x, x_thread); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += + qdot(wl, 
x_thread, s, b, sum); + } + + ws += block_size * bytes_per_pack / pack_factor; + scales += block_size / group_size; + biases += block_size / group_size; + x += block_size; + } + const int remaining = clamp( + static_cast(in_vec_size - k - simd_lid * values_per_thread), + 0, + values_per_thread); + if (remaining > 0) { + U sum = load_vector_safe( + x, x_thread, remaining); + + for (int row = 0; row < results_per_simdgroup; row++) { + auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); + const device T* sl = scales + row * in_vec_size_g; + const device T* bl = biases + row * in_vec_size_g; + + U s = sl[0]; + U b = bl[0]; + result[row] += qdot_safe( + wl, x_thread, s, b, sum, remaining); + } + } + for (int row = 0; row < results_per_simdgroup; row++) { + result[row] = simd_sum(result[row]); + if (simd_lid == 0) { + y[row] = static_cast(result[row]); + } + } + } +} + +template +METAL_FUNC void qvm_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + const int in_vec_size, + const int out_vec_size, + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + constexpr int num_simdgroups = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int tn = 32 / pack_factor; + constexpr int block_size = SIMD_SIZE; + + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; + + typedef float U; + typedef struct { + W_T wi[tn * bytes_per_pack]; + } vec_w; + + thread vec_w w_local; + thread U result[tn * pack_factor] = {0}; + thread U scale = 1; + thread U bias = 0; + thread U x_local = 0; + + // Adjust positions + const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; + const int out_vec_size_g = out_vec_size / group_size; + int out_col 
= pack_factor * tn * (tid.y * num_simdgroups + simd_gid); + ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; + scales += out_col / group_size + simd_lid * out_vec_size_g; + biases += out_col / group_size + simd_lid * out_vec_size_g; + x += tid.x * in_vec_size + simd_lid; + y += tid.x * out_vec_size + out_col; + + if (out_col >= out_vec_size) { + return; + } + + // Loop over in_vec in blocks of block_size + int remaining = in_vec_size % block_size; + if (remaining == 0) { + for (int i = 0; i < in_vec_size; i += block_size) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + } else { + for (int i = block_size; i < in_vec_size; i += block_size) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + + x += block_size; + scales += block_size * out_vec_size_g; + biases += block_size * out_vec_size_g; + ws += block_size * out_vec_size_w; + } + if (static_cast(simd_lid) < remaining) { + x_local = *x; + scale = *scales; + bias = *biases; + w_local = *((device vec_w*)ws); + } else { + x_local = 0; + scale = 0; + bias = 0; + } + qouter( + (thread uint8_t*)&w_local, x_local, scale, bias, result); + } + +// Accumulate in the simdgroup +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + result[k] = simd_sum(result[k]); + } + + // Store the result + if (simd_lid == 0) { +#pragma clang loop unroll(full) + for (int k = 0; k < tn * pack_factor; k++) { + y[k] = static_cast(result[k]); + } + } +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void 
qmm_t_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + const short num_outs = min(BN, N - y_col); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_safe(short2(BK, num_outs)); + 
threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if (!aligned_N && num_outs < BN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_safe(short2(BK, num_outs)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM || num_outs < BN) { + mma_op.store_result_safe(y, N, short2(num_outs, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +METAL_FUNC void qmm_n_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Xs, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int 
WM = 2; + constexpr int WN = 2; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + // Instantiate the appropriate BlockMMA and Loader + using mma_t = mlx::steel:: + BlockMMA; + using loader_x_t = mlx::steel:: + BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + auto wl = (const device uint8_t*)w; + + // Set the block + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + x += y_row * static_cast(K); + wl += y_col * bytes_per_pack / pack_factor; + scales += y_col / group_size; + biases += y_col / group_size; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + const short num_els = min(BM, M - y_row); + loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); + loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); + mma_t mma_op(simd_gid, simd_lid); + + if (num_els < BM) { + if ((K % BK) != 0) { + const int k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, num_els)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(BK, num_els)); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } else { + if ((K % BK) != 0) { + const int 
k_blocks = K / BK; + for (int k = 0; k < k_blocks; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + const short num_k = K - k_blocks * BK; + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_safe(short2(num_k, BM)); + loader_w.load_safe(short2(BN, num_k)); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + } else { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_x.load_unsafe(); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(Xs, Ws); + loader_x.next(); + loader_w.next(); + } + } + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + if (num_els < BM) { + mma_op.store_result_safe(y, N, short2(BN, num_els)); + } else { + mma_op.store_result(y, N); + } +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + 
w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +[[kernel]] void affine_qmv_quad( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + 
const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint quad_gid [[quadgroup_index_in_threadgroup]], + uint quad_lid [[thread_index_in_quadgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_quad_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + quad_gid, + quad_lid); +} + +template +[[kernel]] void affine_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + 
w_strides, + s_strides, + b_strides, + tid); + } + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qvm( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant 
int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + if (batched) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_qvm_split_k( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& in_vec_size [[buffer(5)]], + const constant int& out_vec_size [[buffer(6)]], + const constant int& x_batch_ndims [[buffer(7)]], + const constant int* x_shape [[buffer(8)]], + const constant int64_t* x_strides [[buffer(9)]], + const constant int& w_batch_ndims [[buffer(10)]], + const constant int* w_shape [[buffer(11)]], + const constant int64_t* w_strides [[buffer(12)]], + const constant int64_t* s_strides [[buffer(13)]], + const constant int64_t* b_strides [[buffer(14)]], + const constant int& final_block_size [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + out_vec_size * M, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + + // When (in_vec_size % split_k != 0) the final block needs to be smaller + int in_vec_size_adj = + tid.z % split_k == split_k - 1 ? 
final_block_size : in_vec_size; + + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size_adj, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_qmm_t( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_qmm_n( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases 
[[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template +[[kernel]] void affine_gather_qmv_fast( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* 
w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_fast_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_gather_qmv( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid 
[[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmv_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template +[[kernel]] void affine_gather_qvm( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& in_vec_size [[buffer(7)]], + const constant int& out_vec_size [[buffer(8)]], + const constant int& x_batch_ndims [[buffer(9)]], + const constant int* x_shape [[buffer(10)]], + const constant int64_t* x_strides [[buffer(11)]], + const constant int& w_batch_ndims [[buffer(12)]], + const constant int* w_shape [[buffer(13)]], + const constant int64_t* w_strides [[buffer(14)]], + const constant int64_t* s_strides [[buffer(15)]], + const constant int64_t* b_strides [[buffer(16)]], + const constant int& batch_ndims [[buffer(17)]], + const constant int* batch_shape [[buffer(18)]], + const constant int64_t* lhs_strides [[buffer(19)]], + const constant int64_t* rhs_strides [[buffer(20)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + int M = x_shape[x_batch_ndims]; + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + out_vec_size * M, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + 
x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qvm_impl( + w, + scales, + biases, + x, + y, + in_vec_size, + out_vec_size, + tid, + simd_gid, + simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_gather_qmm_t( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, 
+ b_strides, + tid); + qmm_t_impl( + w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 32, + const int BK = 32, + const int BN = 32> +[[kernel]] void affine_gather_qmm_n( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_impl( + w, scales, biases, x, 
y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using mma_t = mlx::steel::BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + false, + transpose, + BK_padded, + transpose ? BK_padded : BN_padded>; + using loader_x_t = + mlx::steel::BlockLoader; + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Xs[BM * BK_padded]; + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? 
BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_x = short2(k_remain, tgp_bm); + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? 
K : N, + Ws, + simd_group_id, + simd_lane_id); + + // Matrices are all aligned check nothing + if (align_M && align_N) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } else { + // Tile aligned so check outside of the hot loop + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(y, N); + } else { + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_loop_unaligned( + Xs, Ws, mma_op, loader_x, 
loader_w, K_it, tgp_bm, tgp_bn, BK); + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + gemm_loop_finalize( + Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); + } + mma_op.store_result_slice( + y, N, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template +[[kernel]] void affine_quantize( + const device T* w [[buffer(0)]], + device uint8_t* out [[buffer(1)]], + device T* scales [[buffer(2)]], + device T* biases [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr float eps = 1e-7; + constexpr int simd_size = 32; + constexpr float n_bins = (1 << bits) - 1; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_reduce = group_size / simd_size; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; + constexpr int writes_per_pack = + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + + static_assert( + group_size % simd_size == 0, + "Group size must be divisible by simd size."); + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t in_index = offset * values_per_reduce; + size_t out_index = power_of_2_bits + ? offset * writes_per_pack + : offset * bytes_per_pack / writes_per_reduce; + + float w_thread[values_per_reduce]; + float w_min = Limits::max; + float w_max = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + float val = w[in_index + i]; + w_thread[i] = val; + w_min = min(w_min, val); + w_max = max(w_max, val); + } + + w_min = simd_min(w_min); + w_max = simd_max(w_max); + + float scale = max((w_max - w_min) / n_bins, eps); + bool side = abs(w_min) > abs(w_max); + scale = side ? scale : -scale; + float edge = side ? w_min : w_max; + float q0 = round(edge / scale); + bool at_zero = q0 == 0.0f; + scale = at_zero ? 
scale : edge / q0; + float bias = at_zero ? 0 : edge; + + // Write out the scales and biases + size_t gindex = in_index / group_size; + if (in_index % group_size == 0) { + scales[gindex] = static_cast(scale); + biases[gindex] = static_cast(bias); + } + + using OutType = metal::conditional_t; + OutType output = 0; + +#pragma clang loop unroll(full) + for (int i = 0; i < values_per_reduce; i++) { + uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); + if (bits == 8) { + output = val; + } else { + output |= val << (bits * (i % pack_factor)); + } + + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; + output = 0; + } else { +#pragma clang loop unroll(full) + for (int j = 1; j < writes_per_reduce; j++) { + uint8_t sval = simd_shuffle_down(val, j); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); + } + } + } + if (bits == 3 || bits == 6) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } + } else { + if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { + out[out_index / writes_per_reduce] = output; + } + } +} + +template +[[kernel]] void affine_dequantize( + const device uint8_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + device T* out [[buffer(3)]], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int 
bytes_per_pack = get_bytes_per_pack(); + + size_t offset = index.x + grid_dim.x * size_t(index.y); + size_t oindex = offset * pack_factor; + size_t gindex = oindex / group_size; + T scale = scales[gindex]; + T bias = biases[gindex]; + + out += oindex; + + if (bits == 3) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x7) * scale + bias; + out[1] = ((w[0] & 0x38) >> 3) * scale + bias; + out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + out[3] = ((w[1] & 0xe) >> 1) * scale + bias; + out[4] = ((w[1] & 0x70) >> 4) * scale + bias; + out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } else if (bits == 6) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x3f) * scale + bias; + out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } else { + uint val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 2) { + d = (val >> (bits * i)) & 0x03; + } else if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = scale * d + bias; + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/quantized_nax.h b/dist/include/mlx/backend/metal/kernels/quantized_nax.h new 
file mode 100644 index 0000000..c26ff64 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/quantized_nax.h @@ -0,0 +1,1705 @@ +// Copyright © 2023-2024 Apple Inc. + +#include +#include + +using namespace metal; +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +using namespace metal; + +#define MLX_MTL_CONST static constant constexpr const + +MLX_MTL_CONST int SIMD_SIZE = 32; +MLX_MTL_CONST int QUAD_SIZE = 4; + +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + +template +inline U load_vector(const device T* x, thread U* x_thread) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + 
x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < values_per_thread; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + return sum; +} + +template +inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U sum = 0; + + if (bits == 2) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 4.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 64.0f; + } + } + + else if (bits == 3) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 8.0f; + x_thread[i + 2] = x[i + 2] / 64.0f; + x_thread[i + 3] = x[i + 3] / 2.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 128.0f; + x_thread[i + 6] = x[i + 6] / 4.0f; + x_thread[i + 7] = x[i + 7] / 32.0f; + } + } + + else if (bits == 4) { + for (int i = 0; i < 
N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } + + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + + else if (bits == 6) { + for (int i = 0; i < N; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 64.0f; + x_thread[i + 2] = x[i + 2] / 16.0f; + x_thread[i + 3] = x[i + 3] / 4.0f; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + sum += x[i]; + x_thread[i] = x[i]; + } + } + + for (int i = N; i < values_per_thread; i++) { + x_thread[i] = 0; + } + + return sum; +} + +template +inline U qdot( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) 
* x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (values_per_thread / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline U qdot_safe( + const device uint8_t* w, + const thread U* x_thread, + U scale, + U bias, + U sum, + int N) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); 
+ + U accum = 0; + + if (bits == 2) { + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (w[i] & 0x03) + + x_thread[4 * i + 1] * (w[i] & 0x0c) + + x_thread[4 * i + 2] * (w[i] & 0x30) + + x_thread[4 * i + 3] * (w[i] & 0xc0)); + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 3 * i; + + accum += (w[0] & 0x07) * x_thread[0]; + accum += (w[0] & 0x38) * x_thread[1]; + accum += (w[0] & 0xc0) * x_thread[2]; + accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); + + accum += (w[1] & 0x0e) * x_thread[3]; + accum += (w[1] & 0x70) * x_thread[4]; + accum += (w[1] & 0x80) * x_thread[5]; + accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); + + accum += (w[2] & 0x1c) * x_thread[6]; + accum += (w[2] & 0xe0) * x_thread[7]; + } + } + + else if (bits == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (N / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + x_thread += 4 * i; + w += 3 * i; + + accum += (w[0] & 0x3f) * x_thread[0]; + + accum += (w[0] & 0xc0) * x_thread[1]; + accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); + + accum += (w[1] & 0xf0) * 
x_thread[2]; + accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); + + accum += (w[2] & 0xfc) * x_thread[3]; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + accum += x_thread[i] * w[i]; + } + } + + return scale * accum + sum * bias; +} + +template +inline void +qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; + for (int i = 0; i < (values_per_thread / 4); i++) { + result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); + result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); + result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); + result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); + } + } + + else if (bits == 3) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[8 * i] += x * ((w0 & 0x7) * scale + bias); + result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); + result[8 * i + 2] += + x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); + result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); + result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); + result[8 * i + 5] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); + result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); + result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / 16.0f}; + for (int i = 0; i < (values_per_thread / 2); i++) { + result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); + result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); + } + } + + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 
2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { + for (int i = 0; i < (values_per_thread / 4); i++) { + uint8_t w0 = w[3 * i]; + uint8_t w1 = w[3 * i + 1]; + uint8_t w2 = w[3 * i + 2]; + + result[4 * i] += x * ((w0 & 0x3f) * scale + bias); + result[4 * i + 1] += + x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); + result[4 * i + 2] += + x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); + result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); + } + } + + else if (bits == 8) { + for (int i = 0; i < values_per_thread; i++) { + result[i] += x * (scale * w[i] + bias); + } + } +} + +template +inline void +dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + if (bits == 2) { + U s[4] = { + scale, + scale / static_cast(4.0f), + scale / static_cast(16.0f), + scale / static_cast(64.0f)}; + for (int i = 0; i < (N / 4); i++) { + w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; + w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; + w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; + w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; + } + } + + else if (bits == 3) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 3 
* i; + + w_local[0] = (w[0] & 0x7) * scale + bias; + w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; + w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; + w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; + w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; + w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; + w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; + w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } + } + + else if (bits == 4) { + U s[2] = {scale, scale / static_cast(16.0f)}; + for (int i = 0; i < (N / 2); i++) { + w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; + w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; + } + } + + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + + else if (bits == 6) { + for (int i = 0; i < (N / 4); i++) { + w_local += 4 * i; + w += 3 * i; + w_local[0] = (w[0] & 0x3f) * scale + bias; + w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; + w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; + w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; + } + } + + else if (bits == 8) { + for (int i = 0; i < N; i++) { + w_local[i] = scale * w[i] + bias; + } + } +} + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short group_size, + short bits> +struct QuantizedBlockLoader { + static_assert( + BCOLS <= group_size, + "The group 
size should be larger than the columns"); + static_assert( + group_size % BCOLS == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short group_steps = group_size / BCOLS; + + const int src_ld; + const int tile_stride; + short group_step_cnt; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_step_cnt(0), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size), + biases(biases_ + bi * src_ld / group_size) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + if (group_steps > 1) { + group_step_cnt++; + if (group_step_cnt == group_steps) { + group_step_cnt = 0; + scales++; + biases++; + } + } else { + scales++; + biases++; + } + } else { + scales += group_stride; + biases += group_stride; + } + } +}; + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short bits> +struct QuantizedBlockLoader< + T, + BROWS, + BCOLS, + dst_ld, + reduction_dim, + tgp_size, + 32, + bits> { + MLX_MTL_CONST short group_size = 32; + + static_assert( + BCOLS % 
group_size == 0, + "The group size should be divisible by the columns"); + static_assert( + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); + MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; + MLX_MTL_CONST short n_reads = + (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; + MLX_MTL_CONST short n_groups = BCOLS / group_size; + + static_assert( + (BCOLS_PACKED / n_reads) == n_groups, + "Other configurations are not yet supported"); + + const int src_ld; + const int tile_stride; + const int group_stride; + + const short thread_idx; + const short bi; + const short bj; + + const short group_id; + + threadgroup T* dst; + const device uint8_t* src; + const device T* scales; + const device T* biases; + + QuantizedBlockLoader( + const device uint8_t* src_, + const device T* scales_, + const device T* biases_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride( + reduction_dim ? 
BCOLS_PACKED * bytes_per_pack + : BROWS * src_ld * bytes_per_pack / pack_factor), + group_stride(BROWS * src_ld / group_size), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(n_reads * thread_idx / BCOLS_PACKED), + bj((n_reads * thread_idx) % BCOLS_PACKED), + group_id((bj * pack_factor) / group_size), + dst(dst_ + bi * dst_ld + bj * pack_factor), + src(src_ + bi * src_ld * bytes_per_pack / pack_factor + + bj * bytes_per_pack), + scales(scales_ + bi * src_ld / group_size + group_id), + biases(biases_ + bi * src_ld / group_size + group_id) {} + + void load_unsafe() const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); + } + } + + void load_safe(short2 src_tile_dim) const { + if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { + return; + } + + if (reduction_dim == 1 && bi >= src_tile_dim.x) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + if (reduction_dim == 0 && bi >= src_tile_dim.y) { + for (int i = 0; i < n_reads * pack_factor; i++) { + dst[i] = T(0); + } + return; + } + + T scale = *scales; + T bias = *biases; + for (int i = 0; i < n_reads; i++) { + dequantize( + (device uint8_t*)(src + i * bytes_per_pack), + scale, + bias, + dst + i * pack_factor); + } + } + + void next() { + src += tile_stride; + if (reduction_dim == 1) { + // if (group_steps > 1) { + // group_step_cnt++; + // if (group_step_cnt == group_steps) { + // group_step_cnt = 0; + // scales++; + // biases++; + // } + // } else { + scales += n_groups; + biases += n_groups; + // } + } else { + scales += n_groups * group_stride; + biases += n_groups * group_stride; + } + } +}; + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + device T*& y, + int output_stride, + 
const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx = tid.z; + uint32_t w_idx = tid.z; + if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template +METAL_FUNC void adjust_matrix_offsets( + const device T*& x, + const device uint32_t*& w, + const device T*& scales, + const device T*& biases, + const device uint32_t* lhs_indices, + const device uint32_t* rhs_indices, + device T*& y, + int output_stride, + const constant int& batch_ndims, + const constant int* batch_shape, + const constant int64_t* lhs_strides, + const constant int64_t* rhs_strides, + const constant int& x_batch_ndims, + const constant int* x_shape, + const constant int64_t* x_strides, + const constant int& w_batch_ndims, + const constant int* w_shape, + const constant int64_t* w_strides, + const constant int64_t* s_strides, + const constant int64_t* b_strides, + uint3 tid [[threadgroup_position_in_grid]]) { + // Set the input/output matrices + uint32_t x_idx; + uint32_t w_idx; + if (batch_ndims == 1) { + x_idx = lhs_indices[tid.z * lhs_strides[0]]; + w_idx = rhs_indices[tid.z * rhs_strides[0]]; + } else { + ulong2 idx = elem_to_loc_broadcast( + tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); + x_idx = lhs_indices[idx.x]; + w_idx = rhs_indices[idx.y]; + } + 
if (x_batch_ndims == 1) { + x += x_idx * x_strides[0]; + } else { + x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); + } + if (w_batch_ndims == 1) { + w += w_idx * w_strides[0]; + scales += w_idx * s_strides[0]; + biases += w_idx * b_strides[0]; + } else { + ulong3 idx = elem_to_loc_broadcast( + w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); + w += idx.x; + scales += idx.y; + biases += idx.z; + } + y += tid.z * output_stride; +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_t_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + (void)lid; + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BN, + BK, + BK_padded, + 1, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the weight loader + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, 
simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + + const short sgp_sm = min(SM, short(M - (y_row + tm))); + const bool is_unaligned_sm = (sgp_sm != SM); + + const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); + + const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); + const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe(short2(BK, tgp_bn)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(x + kk1, K); + } else { + Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); + } + + Btile.template load(Ws + tn * BK_padded + kk1); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + if constexpr (kAlignedM.value && kAlignedN.value) { + Dtile.store(y + tm * N + tn, N); + } 
else if (kAlignedM.value && sgp_sn == SN) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); + } + }); + }); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +METAL_FUNC void qmm_n_nax_tgp_impl( + const device uint32_t* w, + const device T* scales, + const device T* biases, + const device T* x, + device T* y, + threadgroup T* Ws, + const constant int& K, + const constant int& N, + const constant int& M, + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + (void)M; + + static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); + static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); + + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + BK, + BN, + BN_padded, + 0, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + // Set the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + + auto wl = (const device uint8_t*)w; + + x += y_row * static_cast(K); + wl += y_col * K_w; + scales += y_col * K_g; + biases += y_col * K_g; + y += y_row * static_cast(N) + y_col; + + // Make the x loader and mma operation + // const short num_els = min(BM, M - y_row); + // const short num_outs = min(BN, N - y_col); + loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr 
short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_gid / WN); + const short tn = SN * (simd_gid % WN); + + const short ldb_tgp = BN_padded; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = false; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + NAXTile Dtile; + + Dtile.clear(); + + x += tm * K; + + for (int k = 0; k < K; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + Atile.load(x + kk1, K); + Btile.template load(Ws + tn + kk1 * ldb_tgp); + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + x += BK; + loader_w.next(); + } + + // Store results to device memory + threadgroup_barrier(mem_flags::mem_threadgroup); + + Dtile.store(y + tm * N + tn, N); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const bool batched, + const int BM = 64, + const int BK = 32, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + 
const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool batched, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + device T* y [[buffer(4)]], + const constant int& K [[buffer(5)]], + const constant int& N [[buffer(6)]], + const constant int& M [[buffer(7)]], + const constant int& x_batch_ndims [[buffer(8)]], + const constant int* x_shape [[buffer(9)]], + const constant int64_t* x_strides [[buffer(10)]], + const constant int& w_batch_ndims [[buffer(11)]], + const constant int* w_shape [[buffer(12)]], + const constant int64_t* w_strides [[buffer(13)]], + const constant int64_t* s_strides [[buffer(14)]], + const constant int64_t* b_strides [[buffer(15)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + if (batched) { + adjust_matrix_offsets( + x, + w, + scales, + 
biases, + y, + M * N, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + } + + qmm_n_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const bool aligned_N, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_t_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BK_padded = (BK + 16 / sizeof(T)); + + threadgroup T Ws[BN * BK_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + 
w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_t_nax_tgp_impl( + w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + const int group_size, + const int bits, + const int BM = 64, + const int BK = 64, + const int BN = 64, + const int WM = 2, + const int WN = 2> +[[kernel]] void affine_gather_qmm_n_nax( + const device uint32_t* w [[buffer(0)]], + const device T* scales [[buffer(1)]], + const device T* biases [[buffer(2)]], + const device T* x [[buffer(3)]], + const device uint32_t* lhs_indices [[buffer(4)]], + const device uint32_t* rhs_indices [[buffer(5)]], + device T* y [[buffer(6)]], + const constant int& K [[buffer(7)]], + const constant int& N [[buffer(8)]], + const constant int& M [[buffer(9)]], + const constant int& x_batch_ndims [[buffer(10)]], + const constant int* x_shape [[buffer(11)]], + const constant int64_t* x_strides [[buffer(12)]], + const constant int& w_batch_ndims [[buffer(13)]], + const constant int* w_shape [[buffer(14)]], + const constant int64_t* w_strides [[buffer(15)]], + const constant int64_t* s_strides [[buffer(16)]], + const constant int64_t* b_strides [[buffer(17)]], + const constant int& batch_ndims [[buffer(18)]], + const constant int* batch_shape [[buffer(19)]], + const constant int64_t* lhs_strides [[buffer(20)]], + const constant int64_t* rhs_strides [[buffer(21)]], + uint3 tid [[threadgroup_position_in_grid]], + uint lid [[thread_index_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + threadgroup T Ws[BK * BN_padded]; + + adjust_matrix_offsets( + x, + w, + scales, + biases, + lhs_indices, + rhs_indices, + y, + M * N, + batch_ndims, + batch_shape, + lhs_strides, + rhs_strides, + x_batch_ndims, + x_shape, + x_strides, + w_batch_ndims, + w_shape, + w_strides, + s_strides, + b_strides, + tid); + qmm_n_nax_tgp_impl( + 
w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); +} + +template < + typename T, + int group_size, + int bits, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose> +[[kernel]] void affine_gather_qmm_rhs_nax( + const device T* x [[buffer(0)]], + const device uint32_t* w [[buffer(1)]], + const device T* scales [[buffer(2)]], + const device T* biases [[buffer(3)]], + const device uint32_t* indices [[buffer(4)]], + device T* y [[buffer(5)]], + const constant int& M [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant int& K [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) { + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); + constexpr int BN_padded = (BN + 16 / sizeof(T)); + + using loader_w_t = QuantizedBlockLoader< + T, + transpose ? BN : BK, + transpose ? BK : BN, + transpose ? BK_padded : BN_padded, + transpose, + WM * WN * SIMD_SIZE, + group_size, + bits>; + + threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; + + // Compute the block + const int K_w = K * bytes_per_pack / pack_factor; + const int K_g = K / group_size; + const int N_w = N * bytes_per_pack / pack_factor; + const int N_g = N / group_size; + const int K_it = K / BK; + const size_t stride_w = transpose ? N * K_w : K * N_w; + const size_t stride_s = transpose ? N * K_g : K * N_g; + const int y_row = tid.y * BM; + const int y_col = tid.x * BN; + const size_t y_row_long = size_t(y_row); + const size_t y_col_long = size_t(y_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); + const short tgp_bn = align_N ? 
BN : short(min(BN, N - y_col)); + + // Calculate the final tiles in the case that K is not aligned + const int k_remain = K - K_it * BK; + const short2 tile_w = + transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + // Move x and output to the correct block + auto wl = (const device uint8_t*)w; + x += y_row_long * K; + y += y_row_long * N + y_col_long; + wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; + scales += transpose ? y_col_long * K_g : y_col / group_size; + biases += transpose ? y_col_long * K_g : y_col / group_size; + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = + align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); + const short sgp_sn = + align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); + + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); + + constexpr short BR = transpose ? TN : TK; + constexpr short BC = transpose ? 
TK : TN; + + using AccumType = float; + + using ASubTile = NAXSubTile; + using BSubTile = NAXSubTile; + using DSubTile = NAXSubTile; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = indices[y_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (indices[y_row + n] != index) { + offset_next = n; + index_next = indices[y_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + NAXTile Dtile; + + Dtile.clear(); + + const device T* xn = x + tm * K; + + // Prepare threadgroup loading operations + thread loader_w_t loader_w( + wl + index * stride_w, + scales + index * stride_s, + biases + index * stride_s, + transpose ? K : N, + Ws, + simd_group_id, + simd_lane_id); + + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { + for (int k = 0; k < K_it; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if constexpr (kAlignedN.value) { + loader_w.load_unsafe(); + } else { + loader_w.load_safe( + transpose ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK)); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + if constexpr (kAlignedM.value) { + Atile.load(xn + kk1, K); + } else { + Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); + } + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + xn += BK; + loader_w.next(); + } + + if (!align_K) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_w.load_safe(tile_w); + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + volatile int compiler_barrier; + + const short psk = min(int(SK), max(0, (BK - kk1))); + Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); + + if constexpr (transpose) { + Btile.template load(Ws + tn * BK_padded + kk1); + } else { + Btile.template load(Ws + tn + kk1 * BN_padded); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); + const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); + + // Store results to device memory + if constexpr (kAlignedN.value) { + if (m_lo_lim == 0 && m_hi_lim == SM) { + Dtile.store(y + tm * N + tn, N); + } else { + Dtile.store_slice( + y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); + } + } else { + Dtile.store_slice( + y + tm * N + tn, + N, + short2(0, m_lo_lim), + short2(sgp_sn, m_hi_lim)); + } + }); + }); + } +} \ No newline at end of file diff --git 
a/dist/include/mlx/backend/metal/kernels/quantized_utils.h b/dist/include/mlx/backend/metal/kernels/quantized_utils.h
new file mode 100644
index 0000000..38253f8
--- /dev/null
+++ b/dist/include/mlx/backend/metal/kernels/quantized_utils.h
@@ -0,0 +1,90 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include <metal_simdgroup>
+#include <metal_stdlib>
+
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_aligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    loader_a.load_unsafe();
+    loader_b.load_unsafe();
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <
+    bool rows_aligned,
+    bool cols_aligned,
+    bool transpose,
+    typename T,
+    typename mma_t,
+    typename loader_a_t,
+    typename loader_b_t>
+METAL_FUNC void gemm_loop_unaligned(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const int k_iterations,
+    const short tgp_bm,
+    const short tgp_bn,
+    const short tgp_bk) {
+  for (int k = 0; k < k_iterations; k++) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Load elements into threadgroup memory
+    if (rows_aligned) {
+      loader_a.load_unsafe();
+    } else {
+      loader_a.load_safe(short2(tgp_bk, tgp_bm));
+    }
+    if (cols_aligned) {
+      loader_b.load_unsafe();
+    } else {
+      loader_b.load_safe(
+          transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk));
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    // Multiply and accumulate threadgroup elements
+    mma_op.mma(As, Bs);
+
+    // Prepare for next iteration
+    loader_a.next();
+    loader_b.next();
+  }
+}
+
+template <typename T, typename mma_t, typename loader_a_t, typename loader_b_t>
+METAL_FUNC void gemm_loop_finalize(
+    threadgroup T* As,
+    threadgroup T* Bs,
+    thread mma_t& mma_op,
+    thread loader_a_t& loader_a,
+    thread loader_b_t& loader_b,
+    const short2 tile_a,
+    const short2 tile_b) {
+  loader_a.load_safe(tile_a);
+  loader_b.load_safe(tile_b);
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+  mma_op.mma(As, Bs);
+}
diff --git a/dist/include/mlx/backend/metal/kernels/reduce.h b/dist/include/mlx/backend/metal/kernels/reduce.h
new file mode 100644
index 0000000..ee5c3d5
--- /dev/null
+++ b/dist/include/mlx/backend/metal/kernels/reduce.h
@@ -0,0 +1,5 @@
+#pragma once
+#include "mlx/backend/metal/kernels/reduction/reduce_all.h"
+#include "mlx/backend/metal/kernels/reduction/reduce_col.h"
+#include "mlx/backend/metal/kernels/reduction/reduce_init.h"
+#include "mlx/backend/metal/kernels/reduction/reduce_row.h"
diff --git a/dist/include/mlx/backend/metal/kernels/reduce_utils.h b/dist/include/mlx/backend/metal/kernels/reduce_utils.h
new file mode 100644
index 0000000..279a7af
--- /dev/null
+++ b/dist/include/mlx/backend/metal/kernels/reduce_utils.h
@@ -0,0 +1,6 @@
+// Copyright © 2024 Apple Inc.
+
+#pragma once
+
+#include "mlx/backend/metal/kernels/atomic.h"
+#include "mlx/backend/metal/kernels/reduction/ops.h"
diff --git a/dist/include/mlx/backend/metal/kernels/reduction/ops.h b/dist/include/mlx/backend/metal/kernels/reduction/ops.h
new file mode 100644
index 0000000..11d8e83
--- /dev/null
+++ b/dist/include/mlx/backend/metal/kernels/reduction/ops.h
@@ -0,0 +1,275 @@
+// Copyright © 2023-2024 Apple Inc.
+ +#pragma once + +#include +#include + +#define DEFINE_SIMD_REDUCE() \ + template = true> \ + T simd_reduce(T val) { \ + return simd_reduce_impl(val); \ + } \ + \ + template = true> \ + T simd_reduce(T val) { \ + for (short i = simd_size / 2; i > 0; i /= 2) { \ + val = operator()(val, simd_shuffle_down(val, i)); \ + } \ + return val; \ + } + +static constant constexpr const uint8_t simd_size = 32; + +union bool4_or_uint { + bool4 b; + unsigned int i; +}; + +struct None { + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_store_explicit(out, val, offset); + } +}; + +template +struct And { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_all(val); + } + + static constexpr constant bool init = true; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (!val) { + bool4_or_uint update; + update.b = {true, true, true, true}; + update.b[elem_idx] = false; + mlx_atomic_fetch_and_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (!val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device bool* out, bool val) { + *out &= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a && b; + } +}; + +template +struct Or { + DEFINE_SIMD_REDUCE() + + bool simd_reduce_impl(bool val) { + return simd_any(val); + } + + static constexpr constant bool init = false; + + void atomic_update( + device mlx_atomic* out, + bool val, + int elem_idx, + size_t offset = 0) { + if (val) { + bool4_or_uint update; + update.b = {false, false, false, false}; + update.b[elem_idx] = true; + mlx_atomic_fetch_or_explicit(out, update.i, offset); + } + } + + void + atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { + if (val) { + mlx_atomic_store_explicit(out, val, offset); + } + } + + // Non atomic update + void update(device 
bool* out, bool val) { + *out |= val; + } + + // Operator + bool operator()(bool a, bool b) { + return a || b; + } +}; + +template +struct Sum { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_sum(val); + } + + static constexpr constant U init = U(0); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_add_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a + b; + } +}; + +template +struct Prod { + DEFINE_SIMD_REDUCE() + + template + T simd_reduce_impl(T val) { + return simd_product(val); + } + + static constexpr constant U init = U(1); + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_mul_explicit(out, val, offset); + } + + // Operator + U operator()(U a, U b) { + return a * b; + } +}; + +template +struct Min { + DEFINE_SIMD_REDUCE() + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_min(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } + return simd_min(val); + } + + static constexpr constant U init = Limits::max; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_min_explicit(out, val, offset); + } + + // Operator + template + metal::enable_if_t, T> operator()(T a, T b) { + return a < b ? a : b; + } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a < b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a < b ? 
a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real < b.real ? a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + }; +}; +template +struct Max { + DEFINE_SIMD_REDUCE() + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + return simd_max(val); + } + + template + metal::enable_if_t, T> simd_reduce_impl(T val) { + if (simd_any(val != val)) { + return static_cast(NAN); + } + return simd_max(val); + } + + static constexpr constant U init = Limits::min; + + template + void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { + mlx_atomic_fetch_max_explicit(out, val, offset); + } + + // Operator + template + metal::enable_if_t, T> operator()(T a, T b) { + return a > b ? a : b; + } + + template + metal::enable_if_t, T> operator()(T a, T b) { + if (metal::isnan(a) || metal::isnan(b)) { + return static_cast(NAN); + } else { + return a > b ? a : b; + } + } + + template <> + complex64_t operator()(complex64_t a, complex64_t b) { + bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); + bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); + + if (!real_is_nan && !imag_is_nan) { + return a > b ? a : b; + } else if (real_is_nan && !imag_is_nan) { + return complex64_t( + static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); + } else if (!real_is_nan && imag_is_nan) { + return complex64_t( + a.real > b.real ? 
a.real : b.real, static_cast(NAN)); + } else { + return complex64_t(static_cast(NAN), static_cast(NAN)); + } + } +}; diff --git a/dist/include/mlx/backend/metal/kernels/reduction/reduce_all.h b/dist/include/mlx/backend/metal/kernels/reduction/reduce_all.h new file mode 100644 index 0000000..e0d0839 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/reduction/reduce_all.h @@ -0,0 +1,66 @@ +// Copyright © 2023-2024 Apple Inc. + +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS> +[[kernel]] void all_reduce( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& in_size [[buffer(2)]], + const constant size_t& row_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + + U total = Op::init; + IdxT start_idx = gid.y * IdxT(row_size); + IdxT actual_row = + (start_idx + row_size <= in_size) ? 
row_size : in_size - start_idx; + IdxT blocks = actual_row / (lsize.x * N_READS); + int extra = actual_row - blocks * (lsize.x * N_READS); + extra -= lid.x * N_READS; + start_idx += lid.x * N_READS; + in += start_idx; + + if (extra >= N_READS) { + blocks++; + extra = 0; + } + + for (IdxT b = 0; b < blocks; b++) { + for (int i = 0; i < N_READS; i++) { + total = op(static_cast(in[i]), total); + } + in += lsize.x * N_READS; + } + if (extra > 0) { + for (int i = 0; i < extra; i++) { + total = op(static_cast(in[i]), total); + } + } + + // Reduction within simd group + total = op.simd_reduce(total); + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + shared_vals[simd_group_id] = total; + } + + // Reduction within thread group + threadgroup_barrier(mem_flags::mem_threadgroup); + total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; + total = op.simd_reduce(total); + } + + if (lid.x == 0) { + out[gid.y] = total; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/reduction/reduce_col.h b/dist/include/mlx/backend/metal/kernels/reduction/reduce_col.h new file mode 100644 index 0000000..c109faf --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/reduction/reduce_col.h @@ -0,0 +1,398 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +template +[[kernel]] void col_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + constexpr int n_reads = 4; + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + U totals[n_reads]; + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; + if (column >= reduction_stride) { + return; + } + bool safe = column + n_reads <= reduction_stride; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(lid.y, reduce_shape, reduce_strides); + for (IdxT r = lid.y; r < total_rows; r += lsize.y) { + row = in + loop.location(); + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? 
static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + loop.next(lsize.y, reduce_shape, reduce_strides); + } + + if (lsize.y > 1) { + // lsize.y should be <= 8 + threadgroup U shared_vals[32 * 8 * n_reads]; + for (int i = 0; i < n_reads; i++) { + shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (int i = 0; i < n_reads; i++) { + totals[i] = shared_vals[lid.x * n_reads + i]; + } + for (uint j = 1; j < lsize.y; j++) { + for (int i = 0; i < n_reads; i++) { + totals[i] = + op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], + totals[i]); + } + } + } + } + + if (lid.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} + +template +[[kernel]] void col_reduce_longcolumn( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant size_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]]) { + Op op; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx 
+ lid.x; + + U total = Op::init; + IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); + for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; + r += lsize.y * gsize.z) { + row = in + loop.location(); + total = op(static_cast(*row), total); + loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); + } + + threadgroup U shared_vals[32 * 32]; + shared_vals[lid.y * lsize.x + lid.x] = total; + threadgroup_barrier(mem_flags::mem_threadgroup); + if (lid.y == 0) { + for (uint i = 1; i < lsize.y; i++) { + total = op(total, shared_vals[i * lsize.x + lid.x]); + } + out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = + total; + } +} + +/** + * Our approach is the following simple looped approach: + * 1. Each thread keeps running totals for BN / n_simdgroups outputs. + * 2. Load a tile BM, BN in registers and accumulate in the running totals + * 3. Move ahead by BM steps until the column axis and the non column + * reductions are exhausted. + * 6. If BM == 32 then transpose in SM and simd reduce the running totals. + * Otherwise write in shared memory and BN threads accumulate the running + * totals with a loop. + * 7. 
Write them to the output + */ +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y, reduce_shape, reduce_strides); + for (IdxT r = offset.y; r < total; r += BM) { + row = in + loop.location(); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < 
n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + if (BM == 32) { + constexpr int n_outputs = BN / n_simdgroups; + static_assert( + BM != 32 || n_outputs == n_reads, + "The tile should be selected such that n_outputs == n_reads"); + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. + if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += out_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } + + // Each thread holds n_reads partial results. We write them all out to shared + // memory and threads with offset.y == 0 aggregate the columns and write the + // outputs. + else { + short x_block = offset.x / n_reads; + for (int i = 0; i < n_reads; i++) { + shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (offset.y == 0) { + for (int i = 0; i < n_reads; i++) { + for (int j = 1; j < BM; j++) { + totals[i] = + op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); + } + } + } + + // Write the output. 
+ if (offset.y == 0) { + out += out_idx * IdxT(reduction_stride) + column; + if (safe) { + for (int i = 0; i < n_reads; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } + } +} + +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int BM, + int BN> +[[kernel]] void col_reduce_2pass( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& reduction_stride [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + const constant size_t& non_col_reductions [[buffer(10)]], + const constant size_t& out_size [[buffer(11)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + constexpr int n_simdgroups = 8; + constexpr short tgp_size = n_simdgroups * simd_size; + constexpr short n_reads = (BM * BN) / tgp_size; + constexpr short n_read_blocks = BN / n_reads; + constexpr int n_outputs = BN / n_simdgroups; + constexpr short outer_blocks = 32; + static_assert(BM == 32, "BM should be equal to 32"); + + threadgroup U shared_vals[BN * BM]; + U totals[n_reads]; + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + + for (int i = 0; i < n_reads; i++) { + totals[i] = Op::init; + } + + short lid = simd_group_id * simd_size + simd_lane_id; + short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); + IdxT column = BN * gid.x + offset.x; + bool safe = column + n_reads <= reduction_stride; + + IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); + IdxT block_idx = 
full_idx / IdxT(out_size); + IdxT out_idx = full_idx % IdxT(out_size); + IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); + in += in_idx + column; + + IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); + loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); + for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { + row = in + loop.location(); + + if (safe) { + for (int i = 0; i < n_reads; i++) { + totals[i] = op(static_cast(row[i]), totals[i]); + } + } else { + U vals[n_reads]; + for (int i = 0; i < n_reads; i++) { + vals[i] = + (column + i < reduction_stride) ? static_cast(row[i]) : op.init; + } + for (int i = 0; i < n_reads; i++) { + totals[i] = op(vals[i], totals[i]); + } + } + + loop.next(outer_blocks * BM, reduce_shape, reduce_strides); + } + + // We can use a simd reduction to accumulate across BM so each thread writes + // the partial output to SM and then each simdgroup does BN / n_simdgroups + // accumulations. + for (int i = 0; i < n_reads; i++) { + shared_vals[offset.y * BN + offset.x + i] = totals[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + short2 out_offset(simd_group_id * n_outputs, simd_lane_id); + for (int i = 0; i < n_outputs; i++) { + totals[i] = + op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); + } + + // Write the output. 
+ if (simd_lane_id == 0) { + IdxT out_column = BN * gid.x + out_offset.x; + out += full_idx * IdxT(reduction_stride) + out_column; + if (out_column + n_outputs <= reduction_stride) { + for (int i = 0; i < n_outputs; i++) { + out[i] = totals[i]; + } + } else { + for (int i = 0; out_column + i < reduction_stride; i++) { + out[i] = totals[i]; + } + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/reduction/reduce_init.h b/dist/include/mlx/backend/metal/kernels/reduction/reduce_init.h new file mode 100644 index 0000000..604efa7 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/reduction/reduce_init.h @@ -0,0 +1,8 @@ +// Copyright © 2023-2024 Apple Inc. + +template +[[kernel]] void init_reduce( + device T* out [[buffer(0)]], + uint tid [[thread_position_in_grid]]) { + out[tid] = Op::init; +} diff --git a/dist/include/mlx/backend/metal/kernels/reduction/reduce_row.h b/dist/include/mlx/backend/metal/kernels/reduction/reduce_row.h new file mode 100644 index 0000000..936d75b --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/reduction/reduce_row.h @@ -0,0 +1,369 @@ +// Copyright © 2023-2024 Apple Inc. + +// Row reduction utilities +// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup +// - `threadgroup_reduce` collaborative reduction in the threadgroup such that +// lid.x == 0 holds the reduced value +// - `thread_reduce` simple loop and reduce the row + +/** + * The thread group collaboratively reduces across the rows with bounds + * checking. In the end each thread holds a part of the reduction. 
+ */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* inputs[N_WRITES], + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + Op op; + + // Set up the accumulator registers + for (int i = 0; i < N_WRITES; i++) { + totals[i] = Op::init; + } + + // Loop over the reduction size within thread group + for (int i = 0; i < blocks; i++) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + + inputs[j] += lsize_x * N_READS; + } + } + + // Separate case for the last set as we close the reduction size + int index = lid_x * N_READS; + if (index + N_READS <= extra) { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; i < N_READS; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } else { + for (int j = 0; j < N_WRITES; j++) { + for (int i = 0; index + i < extra; i++) { + totals[j] = op(static_cast(inputs[j][i]), totals[j]); + } + } + } +} + +/** + * Consecutive rows in a contiguous array. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const constant size_t& reduction_size, + int blocks, + int extra, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + inputs[0] = in + lid_x * N_READS; + for (int i = 1; i < N_READS; i++) { + inputs[i] = inputs[i - 1] + reduction_size; + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Consecutive rows in an arbitrarily ordered array. 
+ */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void per_thread_row_reduce( + thread U totals[N_WRITES], + const device T* in, + const int64_t row_idx, + int blocks, + int extra, + const constant int* shape, + const constant int64_t* strides, + const constant int& ndim, + uint lsize_x, + uint lid_x) { + // Set up the input pointers + const device T* inputs[N_WRITES]; + in += lid_x * N_READS; + for (int i = 0; i < N_READS; i++) { + inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); + } + + per_thread_row_reduce( + totals, inputs, blocks, extra, lsize_x, lid_x); +} + +/** + * Reduce within the threadgroup. + */ +template < + typename T, + typename U, + typename Op, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +METAL_FUNC void threadgroup_reduce( + thread U totals[N_WRITES], + threadgroup U* shared_vals, + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + + // Simdgroup first + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(totals[i]); + } + + // Across simdgroups + if (simd_per_group > 1) { + if (simd_lane_id == 0) { + for (int i = 0; i < N_WRITES; i++) { + shared_vals[simd_group_id * N_WRITES + i] = totals[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + U values[N_WRITES]; + for (int i = 0; i < N_WRITES; i++) { + values[i] = (lid.x < simd_per_group) ? 
shared_vals[lid.x * N_WRITES + i] + : op.init; + } + + for (int i = 0; i < N_WRITES; i++) { + totals[i] = op.simd_reduce(values[i]); + } + } +} + +template +METAL_FUNC void +thread_reduce(thread U& total, const device T* row, int blocks, int extra) { + Op op; + for (int i = 0; i < blocks; i++) { + U vals[N_READS]; + for (int j = 0; j < N_READS; j++) { + vals[j] = row[j]; + } + for (int j = 0; j < N_READS; j++) { + total = op(vals[j], total); + } + row += N_READS; + } + for (int i = 0; i < extra; i++) { + total = op(*row++, total); + } +} + +// Reduction kernels +// - `row_reduce_small` depending on the non-row reductions and row size it +// either just loops over everything or a simd collaboratively reduces the +// non_row reductions. In the first case one thread is responsible for one +// output on the 2nd one simd is responsible for one output. +// - `row_reduce_simple` simple contiguous row reduction +// - `row_reduce_looped` simply loop and reduce each row for each non-row +// reduction. One threadgroup is responsible for one output. 
+ +template < + typename T, + typename U, + typename Op, + typename IdxT, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_small( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 tid [[thread_position_in_grid]], + uint3 tsize [[threads_per_grid]]) { + Op op; + + U total_val = Op::init; + LoopedElemToLoc 2)> loop(reduce_ndim); + + // Precompute some row reduction numbers + const device T* row; + int blocks = IdxT(row_size) / N_READS; + int extra = IdxT(row_size) % N_READS; + + if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { + // Simple loop over non_row_reductions and reduce the row in the thread. + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); + in += elem_to_loc(out_idx, shape, strides, ndim); + + for (uint r = 0; r < non_row_reductions; r++) { + row = in + loop.location(); + thread_reduce(total_val, row, blocks, extra); + loop.next(reduce_shape, reduce_strides); + } + + out[out_idx] = total_val; + } else { + // Collaboratively reduce over non_row_reductions in the simdgroup. Each + // thread reduces every 32nd row and then a simple simd reduce. 
+ IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + in += elem_to_loc(out_idx, shape, strides, ndim); + + loop.next(simd_lane_id, reduce_shape, reduce_strides); + + for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { + row = in + loop.location(); + thread_reduce(total_val, row, blocks, extra); + loop.next(simd_size, reduce_shape, reduce_strides); + } + + total_val = op.simd_reduce(total_val); + + if (simd_lane_id == 0) { + out[out_idx] = total_val; + } + } +} + +template < + typename T, + typename U, + typename Op, + typename IdxT = int64_t, + int N_READS = REDUCE_N_READS, + int N_WRITES = REDUCE_N_WRITES> +[[kernel]] void row_reduce_simple( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& reduction_size [[buffer(2)]], + const constant int64_t& out_size [[buffer(3)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + threadgroup U shared_vals[simd_size * N_WRITES]; + U totals[N_WRITES]; + + // Move to the row + IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); + if (out_idx + N_WRITES > out_size) { + out_idx = out_size - N_WRITES; + } + in += out_idx * IdxT(reduction_size); + out += out_idx; + + // Each thread reduces across the row + int blocks = IdxT(reduction_size) / (lsize.x * N_READS); + int extra = reduction_size - blocks * (lsize.x * N_READS); + per_thread_row_reduce( + totals, in, reduction_size, blocks, extra, lsize.x, lid.x); + + // Reduce across the threadgroup + threadgroup_reduce( + totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + for (int i = 0; i < N_WRITES; i++) { + out[i] = totals[i]; + } + } +} + +template < + typename T, + 
typename U, + typename Op, + typename IdxT, + int NDIMS, + int N_READS = REDUCE_N_READS> +[[kernel]] void row_reduce_looped( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int64_t& row_size [[buffer(2)]], + const constant int64_t& non_row_reductions [[buffer(3)]], + const constant int* shape [[buffer(4)]], + const constant int64_t* strides [[buffer(5)]], + const constant int& ndim [[buffer(6)]], + const constant int* reduce_shape [[buffer(7)]], + const constant int64_t* reduce_strides [[buffer(8)]], + const constant int& reduce_ndim [[buffer(9)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_per_group [[simdgroups_per_threadgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + Op op; + threadgroup U shared_vals[simd_size]; + U total = Op::init; + + IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); + + // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it + // needs a small refactor. 
+ in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; + + LoopedElemToLoc 2)> loop(reduce_ndim); + const device T* row; + int blocks = IdxT(row_size) / (lsize.x * N_READS); + int extra = row_size - blocks * (lsize.x * N_READS); + + for (IdxT i = 0; i < non_row_reductions; i++) { + row = in + loop.location(); + + // Each thread reduces across the row + U row_total; + per_thread_row_reduce( + &row_total, &row, blocks, extra, lsize.x, lid.x); + + // Aggregate across rows + total = op(total, row_total); + + loop.next(reduce_shape, reduce_strides); + } + + // Reduce across the threadgroup + threadgroup_reduce( + &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); + + // Write the output + if (lid.x == 0) { + out[out_idx] = total; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/scan.h b/dist/include/mlx/backend/metal/kernels/scan.h new file mode 100644 index 0000000..1668261 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/scan.h @@ -0,0 +1,514 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/binary_ops.h" + +#define DEFINE_SIMD_SCAN() \ + template = true> \ + T simd_scan(T val) { \ + return simd_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_scan(T val) { \ + for (int i = 1; i <= 16; i *= 2) { \ + val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \ + } \ + return val; \ + } + +#define DEFINE_SIMD_EXCLUSIVE_SCAN() \ + template = true> \ + T simd_exclusive_scan(T val) { \ + return simd_exclusive_scan_impl(val); \ + } \ + \ + template = true> \ + T simd_exclusive_scan(T val) { \ + val = simd_scan(val); \ + return simd_shuffle_and_fill_up(val, init, 1); \ + } + +template +struct CumSum { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + + static constexpr constant U init = static_cast(0); + + template + U operator()(U a, T b) { + return a + b; + } + + U simd_scan_impl(U x) { + return simd_prefix_inclusive_sum(x); + } + + U simd_exclusive_scan_impl(U x) { + return simd_prefix_exclusive_sum(x); + } +}; + +template +struct CumProd { + DEFINE_SIMD_SCAN() + DEFINE_SIMD_EXCLUSIVE_SCAN() + + static constexpr constant U init = static_cast(1.0f); + + template + U operator()(U a, T b) { + return a * b; + } + + U simd_scan_impl(U x) { + return simd_prefix_inclusive_product(x); + } + + U simd_exclusive_scan_impl(U x) { + return simd_prefix_exclusive_product(x); + } +}; + +template <> +struct CumProd { + static constexpr constant bool init = true; + + template + bool operator()(bool a, T b) { + return a & static_cast(b); + } + + bool simd_scan(bool x) { + for (int i = 1; i <= 16; i *= 2) { + bool other = simd_shuffle_and_fill_up(x, init, i); + x &= other; + } + return x; + } + + bool simd_exclusive_scan(bool x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMax { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return (a >= b) ? 
a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = (x >= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumMin { + static constexpr constant U init = Limits::max; + + template + U operator()(U a, T b) { + return (a <= b) ? a : b; + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = (x <= other) ? x : other; + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +struct CumLogaddexp { + static constexpr constant U init = Limits::min; + + template + U operator()(U a, T b) { + return LogAddExp{}(a, static_cast(b)); + } + + U simd_scan(U x) { + for (int i = 1; i <= 16; i *= 2) { + U other = simd_shuffle_and_fill_up(x, init, i); + x = LogAddExp{}(x, other); + } + return x; + } + + U simd_exclusive_scan(U x) { + x = simd_scan(x); + return simd_shuffle_and_fill_up(x, init, 1); + } +}; + +template +inline void load_unsafe(U values[N_READS], const device T* input) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = input[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = input[i]; + } + } +} + +template +inline void load_safe( + U values[N_READS], + const device T* input, + int start, + int total, + U init) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + values[N_READS - i - 1] = + (start + N_READS - i - 1 < total) ? input[i] : init; + } + } else { + for (int i = 0; i < N_READS; i++) { + values[i] = (start + i < total) ? 
input[i] : init; + } + } +} + +template +inline void write_unsafe(U values[N_READS], device U* out) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + out[i] = values[N_READS - i - 1]; + } + } else { + for (int i = 0; i < N_READS; i++) { + out[i] = values[i]; + } + } +} + +template +inline void write_safe(U values[N_READS], device U* out, int start, int total) { + if (reverse) { + for (int i = 0; i < N_READS; i++) { + if (start + N_READS - i - 1 < total) { + out[i] = values[N_READS - i - 1]; + } + } + } else { + for (int i = 0; i < N_READS; i++) { + if (start + i < total) { + out[i] = values[i]; + } + } + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void contiguous_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + Op op; + + // Position the pointers + size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; + in += offset; + out += offset; + + // Compute the number of simd_groups + uint simd_groups = lsize.x / simd_size; + + // Allocate memory + U prefix = Op::init; + U values[N_READS]; + threadgroup U simdgroup_sums[32]; + + // Loop over the reduced axis in blocks of size ceildiv(axis_size, + // N_READS*lsize) + // Read block + // Compute inclusive scan of the block + // Compute inclusive scan per thread + // Compute exclusive scan of thread sums in simdgroup + // Write simdgroup sums in SM + // Compute exclusive scan of simdgroup sums + // Compute the output by scanning prefix, prev_simdgroup, prev_thread, + // value + // Write block + + for (uint r = 0; r < ceildiv(axis_size, N_READS 
* lsize.x); r++) { + // Compute the block offset + uint offset = r * lsize.x * N_READS + lid.x * N_READS; + + // Read the values + if (reverse) { + if ((offset + N_READS) < axis_size) { + load_unsafe( + values, in + axis_size - offset - N_READS); + } else { + load_safe( + values, + in + axis_size - offset - N_READS, + offset, + axis_size, + Op::init); + } + } else { + if ((offset + N_READS) < axis_size) { + load_unsafe(values, in + offset); + } else { + load_safe( + values, in + offset, offset, axis_size, Op::init); + } + } + + // Compute an inclusive scan per thread + for (int i = 1; i < N_READS; i++) { + values[i] = op(values[i], values[i - 1]); + } + + // Compute exclusive scan of thread sums + U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); + + // Write simdgroup_sums to SM + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_lane_id == simd_size - 1) { + simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute exclusive scan of simdgroup_sums + if (simd_group_id == 0) { + U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); + simdgroup_sums[simd_lane_id] = prev_simdgroup; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Compute the output + for (int i = 0; i < N_READS; i++) { + values[i] = op(values[i], prefix); + values[i] = op(values[i], simdgroup_sums[simd_group_id]); + values[i] = op(values[i], prev_thread); + } + + // Write the values + if (reverse) { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe( + values, out + axis_size - offset - N_READS); + } else { + write_safe( + values, out + axis_size - offset - N_READS, offset, axis_size); + } + } else { + if (lid.x == 0 && offset == 0) { + out[axis_size - 1] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe( + values, out + axis_size - offset - 1 - N_READS); + } else { + write_safe( + values, + out + axis_size - offset - 1 
- N_READS, + offset + 1, + axis_size); + } + } + } else { + if (inclusive) { + if ((offset + N_READS) < axis_size) { + write_unsafe(values, out + offset); + } else { + write_safe( + values, out + offset, offset, axis_size); + } + } else { + if (lid.x == 0 && offset == 0) { + out[0] = Op::init; + } + if ((offset + N_READS + 1) < axis_size) { + write_unsafe(values, out + offset + 1); + } else { + write_safe( + values, out + offset + 1, offset + 1, axis_size); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Share the prefix + if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { + simdgroup_sums[0] = values[N_READS - 1]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + prefix = simdgroup_sums[0]; + } +} + +template < + typename T, + typename U, + typename Op, + int N_READS, + bool inclusive, + bool reverse> +[[kernel]] void strided_scan( + const device T* in [[buffer(0)]], + device U* out [[buffer(1)]], + const constant size_t& axis_size [[buffer(2)]], + const constant size_t& stride [[buffer(3)]], + const constant size_t& stride_blocks [[buffer(4)]], + uint3 gid [[threadgroup_position_in_grid]], + uint3 gsize [[threadgroups_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int simd_size = 32; + constexpr int BM = 32; + constexpr int BN = 32; + constexpr int BN_pad = 32 + 16 / sizeof(U); + constexpr int n_simds = BN / N_READS; + constexpr int n_scans = BN / n_simds; + Op op; + + threadgroup U read_buffer[BM * BN_pad]; + U values[n_scans]; + U prefix[n_scans]; + for (int i = 0; i < n_scans; i++) { + prefix[i] = Op::init; + } + + // Compute offsets + size_t full_gid = gid.y + gsize.y * size_t(gid.z); + size_t offset = full_gid / stride_blocks * axis_size * stride; + size_t global_index_x = full_gid % stride_blocks * BN; + uint read_offset_y = (lid.x * N_READS) / BN; + uint read_offset_x = 
(lid.x * N_READS) % BN; + uint scan_offset_y = simd_lane_id; + uint scan_offset_x = simd_group_id * n_scans; + + uint stride_limit = stride - global_index_x; + in += offset + global_index_x + read_offset_x; + out += offset + global_index_x + read_offset_x; + threadgroup U* read_into = + read_buffer + read_offset_y * BN_pad + read_offset_x; + threadgroup U* read_from = + read_buffer + scan_offset_y * BN_pad + scan_offset_x; + + for (uint j = 0; j < axis_size; j += BM) { + // Calculate the indices for the current thread + uint index_y = j + read_offset_y; + uint check_index_y = index_y; + if (reverse) { + index_y = axis_size - 1 - index_y; + } + + // Read in SM + threadgroup_barrier(mem_flags::mem_threadgroup); + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + read_into[i] = in[index_y * stride + i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + read_into[i] = in[index_y * stride + i]; + } else { + read_into[i] = Op::init; + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Read strided into registers + for (int i = 0; i < n_scans; i++) { + values[i] = read_from[i]; + } + simdgroup_barrier(mem_flags::mem_threadgroup); + + // Perform the scan + for (int i = 0; i < n_scans; i++) { + values[i] = op.simd_scan(values[i]); + values[i] = op(values[i], prefix[i]); + prefix[i] = simd_shuffle(values[i], simd_size - 1); + } + + // Write to SM + for (int i = 0; i < n_scans; i++) { + read_from[i] = values[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write to device memory + if (!inclusive) { + if (check_index_y == 0) { + if ((read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + out[index_y * stride + i] = Op::init; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = Op::init; + } + } + } + 
} + if (reverse) { + index_y -= 1; + check_index_y += 1; + } else { + index_y += 1; + check_index_y += 1; + } + } + if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { + for (int i = 0; i < N_READS; i++) { + out[index_y * stride + i] = read_into[i]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { + out[index_y * stride + i] = read_into[i]; + } + } + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/sdpa_vector.h b/dist/include/mlx/backend/metal/kernels/sdpa_vector.h new file mode 100644 index 0000000..96d22d8 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/sdpa_vector.h @@ -0,0 +1,415 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +constant bool has_mask [[function_constant(20)]]; +constant bool query_transposed [[function_constant(21)]]; +constant bool do_causal [[function_constant(22)]]; +constant bool bool_mask [[function_constant(23)]]; +constant bool float_mask [[function_constant(24)]]; +constant bool has_sinks [[function_constant(25)]]; + +template +[[kernel]] void sdpa_vector( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device T* out [[buffer(3)]], + const constant int& gqa_factor [[buffer(4)]], + const constant int& N [[buffer(5)]], + const constant size_t& k_head_stride [[buffer(6)]], + const constant size_t& k_seq_stride [[buffer(7)]], + const constant size_t& v_head_stride [[buffer(8)]], + const constant size_t& v_seq_stride [[buffer(9)]], + const constant float& scale [[buffer(10)]], + const device bool* bmask [[buffer(11), function_constant(bool_mask)]], + const device T* fmask [[buffer(12), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(13), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(14), function_constant(has_mask)]], + const constant int& 
mask_head_stride + [[buffer(15), function_constant(has_mask)]], + const device T* sinks [[buffer(16), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(17), function_constant(has_sinks)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + int inner_k_stride = BN * int(k_seq_stride); + int inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + queries += q_offset * D + simd_lid * qk_per_thread; + keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + + simd_lid * qk_per_thread; + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + + out += o_offset * V + simd_gid * v_per_thread; + + // Read the query and 0 the output accumulator + for (int i = 0; i < qk_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + if (has_sinks && simd_gid == 0) { + max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + sum_exp_score = 1; + } + + // For each key + for (int i = simd_gid; i < N; i += BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + if (use_key) { + // Read the key + for (int j = 0; j < qk_per_thread; j++) { + k[j] = keys[j]; + } + + // Compute the i-th score + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + if (float_mask) { + score += static_cast(fmask[0]); + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * values[j]; + } + } + + // Move the pointers to the next 
kv + keys += inner_k_stride; + values += inner_v_stride; + if (bool_mask) { + bmask += BN * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * mask_kv_seq_stride; + } + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + // Now we need to aggregate all the outputs + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} + +template +[[kernel]] void sdpa_vector_2pass_1( + const device T* queries [[buffer(0)]], + const device T* keys [[buffer(1)]], + const device T* values [[buffer(2)]], + device float* out [[buffer(3)]], + device float* sums [[buffer(4)]], + device float* maxs [[buffer(5)]], + const constant int& gqa_factor [[buffer(6)]], + const constant int& N [[buffer(7)]], + const constant size_t& k_head_stride [[buffer(8)]], + const constant size_t& k_seq_stride [[buffer(9)]], + const constant size_t& v_head_stride [[buffer(10)]], + const constant size_t& v_seq_stride [[buffer(11)]], + const constant float& scale [[buffer(12)]], + const device bool* bmask [[buffer(13), function_constant(bool_mask)]], + const device T* fmask [[buffer(14), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(15), 
function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(16), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(17), function_constant(has_mask)]], + const device T* sinks [[buffer(18), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(19), function_constant(has_sinks)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 8; + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V / BD; + int inner_k_stride = BN * int(k_seq_stride); + int inner_v_stride = BN * int(v_seq_stride); + constexpr int blocks = 32; + + typedef float U; + + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int block_idx = tid.z; + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? 
tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + + queries += q_offset * D + simd_lid * qk_per_thread; + keys += kv_head_idx * k_head_stride + + (block_idx * BN + simd_gid) * k_seq_stride + simd_lid * qk_per_thread; + values += kv_head_idx * v_head_stride + + (block_idx * BN + simd_gid) * v_seq_stride + simd_lid * v_per_thread; + out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + simd_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + (block_idx * BN + simd_gid) * mask_kv_seq_stride + + q_seq_idx * mask_q_seq_stride; + } + sums += o_offset * blocks + block_idx; + maxs += o_offset * blocks + block_idx; + + // Read the query and 0 the output accumulator + for (int i = 0; i < qk_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = Limits::finite_min; + U sum_exp_score = 0; + if (has_sinks && block_idx == 0 && simd_gid == 0) { + int q_head_idx = q_batch_head_idx % num_q_heads; + max_score = static_cast(sinks[q_head_idx]); + sum_exp_score = 1; + } + + // For each key + for (int i = block_idx * BN + simd_gid; i < N; i += blocks * BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= Limits::finite_min); + } + if (use_key) { + // Read the key + for (int i = 0; i < qk_per_thread; i++) { + k[i] = keys[i]; + } + + // Compute the i-th score + U score = 0; + for (int i = 0; i < qk_per_thread; i++) { + score += q[i] * k[i]; + } + score = simd_sum(score); + + if (float_mask) { + score += fmask[0]; + } + + // Update the accumulators + U new_max = max(max_score, score); + U factor = fast::exp(max_score 
- new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update the output accumulator + for (int i = 0; i < v_per_thread; i++) { + o[i] = o[i] * factor + exp_score * values[i]; + } + } + + // Move the pointers to the next kv + keys += blocks * inner_k_stride; + values += blocks * inner_v_stride; + if (bool_mask) { + bmask += BN * blocks * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * blocks * mask_kv_seq_stride; + } + } + + // Each thread has a partial part of the output so we need to combine them. + + // First let's communicate the max and sum_exp + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = (simd_lid < BN) ? max_scores[simd_lid] : -1e9; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = (simd_lid < BN) ? sum_exp_scores[simd_lid] : 0; + sum_exp_score = simd_sum(sum_exp_score * factor); + + // Write the sum and new max + if (simd_gid == 0) { + sums[0] = sum_exp_score; + maxs[0] = new_max; + } + + // Now we need to aggregate all the outputs + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BN + simd_gid] = + o[i] * fast::exp(max_scores[simd_gid] - new_max); + threadgroup_barrier(mem_flags::mem_threadgroup); + + // And write the output + if (simd_gid == 0) { + U output = outputs[simd_lid * BN]; + for (int j = 1; j < BN; j++) { + output += outputs[simd_lid * BN + j]; + } + out[i] = static_cast(output); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } +} + +template +[[kernel]] void sdpa_vector_2pass_2( + const device float* partials [[buffer(0)]], + const device float* sums [[buffer(1)]], + const device float* maxs [[buffer(2)]], + device T* out [[buffer(3)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid 
[[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D / BD; + constexpr int blocks = 32; + + typedef float U; + + thread U o[elem_per_thread]; + threadgroup U outputs[BN * BD]; + + // Adjust positions + const int head_idx = tid.x; + const int q_seq_idx = tid.y; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; + partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; + sums += q_offset * blocks; + maxs += q_offset * blocks; + out += q_offset * D + simd_gid * elem_per_thread; + + // First everybody reads the max and sum_exp + U max_score = maxs[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + U sum_exp_score = simd_sum(sums[simd_lid] * factor); + + // Now read the block into registers and then use shared memory to transpose + // it + for (int i = 0; i < elem_per_thread; i++) { + o[i] = partials[i]; + } + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // And write the output + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/softmax.h b/dist/include/mlx/backend/metal/kernels/softmax.h new file mode 100644 index 0000000..6ea4ac7 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/softmax.h @@ -0,0 +1,190 @@ +// Copyright © 2023-2024 Apple Inc. + +template +inline T softmax_exp(T x) { + // Softmax doesn't need high precision exponential cause x is gonna be in + // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). 
+ return fast::exp(x); +} + +template +[[kernel]] void softmax_single_row( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint _lid [[thread_position_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + int lid = _lid; + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + AccT ld[N_READS]; + + in += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + ld[i] = AccT(in[i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + ld[i] = + ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; + } + } + if (simd_group_id == 0) { + local_max[simd_lane_id] = Limits::min; + local_normalizer[simd_lane_id] = 0; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Get the max + AccT maxval = Limits::finite_min; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < ld[i]) ? 
ld[i] : maxval; + } + maxval = simd_max(maxval); + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + maxval = simd_max(local_max[simd_lane_id]); + if (simd_lane_id == 0) { + local_max[0] = maxval; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = local_max[0]; + + // Compute exp(x_i - maxval) and store the partial sums in local_normalizer + AccT normalizer = 0; + for (int i = 0; i < N_READS; i++) { + AccT exp_x = softmax_exp(ld[i] - maxval); + ld[i] = exp_x; + normalizer += exp_x; + } + normalizer = simd_sum(normalizer); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (simd_group_id == 0) { + normalizer = simd_sum(local_normalizer[simd_lane_id]); + if (simd_lane_id == 0) { + local_normalizer[0] = normalizer; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = 1 / local_normalizer[0]; + + // Normalize and write to the output + out += gid * size_t(axis_size) + lid * N_READS; + if (lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[i] = T(ld[i] * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((lid * N_READS + i) < axis_size) { + out[i] = T(ld[i] * normalizer); + } + } + } +} + +template +[[kernel]] void softmax_looped( + const device T* in, + device T* out, + constant int& axis_size, + uint gid [[threadgroup_position_in_grid]], + uint lid [[thread_position_in_threadgroup]], + uint lsize [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + in += gid * size_t(axis_size); + + constexpr int SIMD_SIZE = 32; + + threadgroup AccT local_max[SIMD_SIZE]; + threadgroup AccT local_normalizer[SIMD_SIZE]; + + // Get the max and the normalizer in one go + AccT prevmax; + AccT maxval = Limits::finite_min; + AccT 
normalizer = 0; + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + AccT vals[N_READS]; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + vals[i] = AccT(in[offset + i]); + } + } else { + for (int i = 0; i < N_READS; i++) { + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; + } + } + prevmax = maxval; + for (int i = 0; i < N_READS; i++) { + maxval = (maxval < vals[i]) ? vals[i] : maxval; + } + normalizer *= softmax_exp(prevmax - maxval); + for (int i = 0; i < N_READS; i++) { + normalizer += softmax_exp(vals[i] - maxval); + } + } + // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * + // lsize) parts. We need to combine them. + // 1. We start by finding the max across simd groups + // 2. We then change the partial normalizers to account for a possible + // change in max + // 3. We sum all normalizers + prevmax = maxval; + maxval = simd_max(maxval); + normalizer *= softmax_exp(prevmax - maxval); + normalizer = simd_sum(normalizer); + + // Now the normalizer and max value is correct for each simdgroup. We write + // them shared memory and combine them. 
+ prevmax = maxval; + if (simd_lane_id == 0) { + local_max[simd_group_id] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = simd_max(local_max[simd_lane_id]); + normalizer *= softmax_exp(prevmax - maxval); + if (simd_lane_id == 0) { + local_normalizer[simd_group_id] = normalizer; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + normalizer = simd_sum(local_normalizer[simd_lane_id]); + normalizer = 1 / normalizer; + + // Finally given the normalizer and max value we can directly write the + // softmax output + out += gid * size_t(axis_size); + for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); + r++) { + int offset = r * lsize * N_READS + lid * N_READS; + if (offset + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } else { + for (int i = 0; i < N_READS; i++) { + if (offset + i < axis_size) { + out[offset + i] = + T(softmax_exp(in[offset + i] - maxval) * normalizer); + } + } + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/sort.h b/dist/include/mlx/backend/metal/kernels/sort.h new file mode 100644 index 0000000..c439287 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/sort.h @@ -0,0 +1,715 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#define MLX_MTL_CONST static constant constexpr const +#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") + +using namespace metal; + +// Based on GPU merge sort algorithm at +// https://github.com/NVIDIA/cccl/tree/main/cub/cub + +/////////////////////////////////////////////////////////////////////////////// +// Thread-level sort +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC void thread_swap(thread T& a, thread T& b) { + T w = a; + a = b; + b = w; +} + +template +struct Init { + static constexpr constant T v = Limits::max; +}; + +template +struct Init>> { + static constexpr constant T v = metal::numeric_limits::quiet_NaN(); +}; + +template +struct LessThan { + static constexpr constant T init = Init::v; + METAL_FUNC bool operator()(T a, T b) const { + if constexpr ( + metal::is_floating_point_v || metal::is_same_v) { + bool an = isnan(a); + bool bn = isnan(b); + if (an | bn) { + return (!an) & bn; + } + } + return a < b; + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short N_PER_THREAD, + typename CompareOp> +struct ThreadSort { + static METAL_FUNC void sort( + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; + MLX_MTL_LOOP_UNROLL + for (short i = 0; i < N_PER_THREAD; ++i) { + MLX_MTL_LOOP_UNROLL + for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { + if (op(vals[j + 1], vals[j])) { + thread_swap(vals[j + 1], vals[j]); + if (ARG_SORT) { + thread_swap(idxs[j + 1], idxs[j]); + } + } + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Threadgroup-level sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp> +struct BlockMergeSort { + using thread_sort_t = + ThreadSort; + static METAL_FUNC int merge_partition( + 
const threadgroup ValT* As, + const threadgroup ValT* Bs, + short A_sz, + short B_sz, + short sort_md) { + CompareOp op; + + short A_st = max(0, sort_md - B_sz); + short A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + short md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } + + static METAL_FUNC void merge_step( + const threadgroup ValT* As, + const threadgroup ValT* Bs, + const threadgroup IdxT* As_idx, + const threadgroup IdxT* Bs_idx, + short A_sz, + short B_sz, + thread ValT (&vals)[N_PER_THREAD], + thread IdxT (&idxs)[N_PER_THREAD]) { + CompareOp op; + short a_idx = 0; + short b_idx = 0; + + for (int i = 0; i < N_PER_THREAD; ++i) { + auto a = As[a_idx]; + auto b = Bs[b_idx]; + bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); + + vals[i] = pred ? b : a; + if (ARG_SORT) { + idxs[i] = pred ? Bs_idx[b_idx] : As_idx[a_idx]; + } + + b_idx += short(pred); + a_idx += short(!pred); + } + } + + static METAL_FUNC void sort( + threadgroup ValT* tgp_vals [[threadgroup(0)]], + threadgroup IdxT* tgp_idxs [[threadgroup(1)]], + int size_sorted_axis, + uint3 lid [[thread_position_in_threadgroup]]) { + // Get thread location + int idx = lid.x * N_PER_THREAD; + + // Load from shared memory + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; ++i) { + thread_vals[i] = tgp_vals[idx + i]; + if (ARG_SORT) { + thread_idxs[i] = tgp_idxs[idx + i]; + } + } + + // Per thread sort + if (idx < size_sorted_axis) { + thread_sort_t::sort(thread_vals, thread_idxs); + } + + // Do merges using threadgroup memory + for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; + merge_threads *= 2) { + // Update threadgroup memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = 
thread_idxs[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Find location in merge step + int merge_group = lid.x / merge_threads; + int merge_lane = lid.x % merge_threads; + + int sort_sz = N_PER_THREAD * merge_threads; + int sort_st = N_PER_THREAD * merge_threads * merge_group; + + // As = tgp_vals[A_st:A_ed] is sorted + // Bs = tgp_vals[B_st:B_ed] is sorted + int A_st = sort_st; + int A_ed = sort_st + sort_sz / 2; + int B_st = sort_st + sort_sz / 2; + int B_ed = sort_st + sort_sz; + + const threadgroup ValT* As = tgp_vals + A_st; + const threadgroup ValT* Bs = tgp_vals + B_st; + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Find a partition of merge elements + // Ci = merge(As[partition:], Bs[sort_md - partition:]) + // of size N_PER_THREAD for each merge lane i + // C = [Ci] is sorted + int sort_md = N_PER_THREAD * merge_lane; + int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); + + As += partition; + Bs += sort_md - partition; + + A_sz -= partition; + B_sz -= sort_md - partition; + + const threadgroup IdxT* As_idx = + ARG_SORT ? tgp_idxs + A_st + partition : nullptr; + const threadgroup IdxT* Bs_idx = + ARG_SORT ? 
tgp_idxs + B_st + sort_md - partition : nullptr; + + // Merge starting at the partition and store results in thread registers + merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); + } + + // Write out to shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + tgp_vals[idx + i] = thread_vals[i]; + if (ARG_SORT) { + tgp_idxs[idx + i] = thread_idxs[i]; + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Kernel sort +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMergeSort { + using ValT = T; + using IdxT = uint; + using block_merge_sort_t = BlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + const device T* inp, + device U* out, + const constant int& size_sorted_axis, + const constant int& in_stride_sorted_axis, + const constant int& out_stride_sorted_axis, + const constant int& in_stride_segment_axis, + const constant int& out_stride_segment_axis, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + inp += tid.y * in_stride_segment_axis; + out += tid.y * out_stride_segment_axis; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + tgp_vals[i] = i < size_sorted_axis ? 
inp[i * in_stride_sorted_axis] + : ValT(CompareOp::init); + if (ARG_SORT) { + tgp_idxs[i] = i; + } + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { + if (ARG_SORT) { + out[i * out_stride_sorted_axis] = tgp_idxs[i]; + } else { + out[i * out_stride_sorted_axis] = tgp_vals[i]; + } + } + } +}; + +template < + typename T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& in_stride_segment_axis [[buffer(5)]], + const constant int& out_stride_segment_axis [[buffer(6)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; + + if (ARG_SORT) { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + in_stride_segment_axis, + out_stride_segment_axis, + tgp_vals, + nullptr, + tid, + lid); + } +} + +constant constexpr const int zero_helper = 0; + +template < + typename 
T, + typename U, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( + const device T* inp [[buffer(0)]], + device U* out [[buffer(1)]], + const constant int& size_sorted_axis [[buffer(2)]], + const constant int& in_stride_sorted_axis [[buffer(3)]], + const constant int& out_stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape [[buffer(6)]], + const constant int64_t* in_nc_strides [[buffer(7)]], + const constant int64_t* out_nc_strides [[buffer(8)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = + KernelMergeSort; + using ValT = typename sort_kernel::ValT; + using IdxT = typename sort_kernel::IdxT; + + auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); + auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); + inp += in_block_idx; + out += out_block_idx; + + if (ARG_SORT) { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + tgp_idxs, + tid, + lid); + } else { + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + sort_kernel::block_sort( + inp, + out, + size_sorted_axis, + in_stride_sorted_axis, + out_stride_sorted_axis, + zero_helper, + zero_helper, + tgp_vals, + nullptr, + tid, + lid); + } +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +struct KernelMultiBlockMergeSort { + using block_merge_sort_t = BlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; + + static METAL_FUNC void block_sort( + 
const device ValT* inp, + device ValT* out_vals, + device IdxT* out_idxs, + const constant int& size_sorted_axis, + const constant int& stride_sorted_axis, + threadgroup ValT* tgp_vals, + threadgroup IdxT* tgp_idxs, + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // tid.y tells us the segment index + int base_idx = tid.x * N_PER_BLOCK; + + // Copy into threadgroup memory + for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] + : ValT(CompareOp::init); + tgp_idxs[i] = idx; + } + + // Sort elements within the block + threadgroup_barrier(mem_flags::mem_threadgroup); + + block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Write output + for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + out_vals[idx] = tgp_vals[i]; + out_idxs[idx] = tgp_idxs[i]; + } + } + } + + static METAL_FUNC int merge_partition( + const device ValT* As, + const device ValT* Bs, + int A_sz, + int B_sz, + int sort_md) { + CompareOp op; + + int A_st = max(0, sort_md - B_sz); + int A_ed = min(sort_md, A_sz); + + while (A_st < A_ed) { + int md = A_st + (A_ed - A_st) / 2; + auto a = As[md]; + auto b = Bs[sort_md - 1 - md]; + + if (op(b, a)) { + A_ed = md; + } else { + A_st = md + 1; + } + } + + return A_ed; + } +}; + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( + const device ValT* inp [[buffer(0)]], + device ValT* out_vals [[buffer(1)]], + device IdxT* out_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& stride_sorted_axis [[buffer(4)]], + const constant int& nc_dim [[buffer(5)]], + const constant int* nc_shape 
[[buffer(6)]], + const constant int64_t* nc_strides [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); + inp += block_idx; + out_vals += tid.y * size_sorted_axis; + out_idxs += tid.y * size_sorted_axis; + + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + + sort_kernel::block_sort( + inp, + out_vals, + out_idxs, + size_sorted_axis, + stride_sorted_axis, + tgp_vals, + tgp_idxs, + tid, + lid); +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD> +[[kernel]] void mb_block_partition( + device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals [[buffer(1)]], + const device IdxT* dev_idxs [[buffer(2)]], + const constant int& size_sorted_axis [[buffer(3)]], + const constant int& merge_tiles [[buffer(4)]], + const constant int& n_blocks [[buffer(5)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 tgp_dims [[threads_per_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD>; + + block_partitions += tid.y * tgp_dims.x; + dev_vals += tid.y * size_sorted_axis; + dev_idxs += tid.y * size_sorted_axis; + + for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { + // Find location in merge step + int merge_group = i / merge_tiles; + int merge_lane = i % merge_tiles; + + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + + int A_st = min(size_sorted_axis, sort_st); + int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); + int B_st = A_ed; + int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); + + int 
partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); + int partition = sort_kernel::merge_partition( + dev_vals + A_st, + dev_vals + B_st, + A_ed - A_st, + B_ed - B_st, + partition_at); + + block_partitions[i] = A_st + partition; + } +} + +template < + typename ValT, + typename IdxT, + bool ARG_SORT, + short BLOCK_THREADS, + short N_PER_THREAD, + typename CompareOp = LessThan> +[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void +mb_block_merge( + const device IdxT* block_partitions [[buffer(0)]], + const device ValT* dev_vals_in [[buffer(1)]], + const device IdxT* dev_idxs_in [[buffer(2)]], + device ValT* dev_vals_out [[buffer(3)]], + device IdxT* dev_idxs_out [[buffer(4)]], + const constant int& size_sorted_axis [[buffer(5)]], + const constant int& merge_tiles [[buffer(6)]], + const constant int& num_tiles [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + using sort_kernel = KernelMultiBlockMergeSort< + ValT, + IdxT, + ARG_SORT, + BLOCK_THREADS, + N_PER_THREAD, + CompareOp>; + + using block_sort_t = typename sort_kernel::block_merge_sort_t; + + block_partitions += tid.y * (num_tiles + 1); + dev_vals_in += tid.y * size_sorted_axis; + dev_idxs_in += tid.y * size_sorted_axis; + dev_vals_out += tid.y * size_sorted_axis; + dev_idxs_out += tid.y * size_sorted_axis; + + int block_idx = tid.x; + int merge_group = block_idx / merge_tiles; + int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; + int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; + int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; + + int A_st = block_partitions[block_idx + 0]; + int A_ed = block_partitions[block_idx + 1]; + int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); + int B_ed = min( + size_sorted_axis, + 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); + + if ((block_idx % merge_tiles) == merge_tiles - 1) { + A_ed = 
min(size_sorted_axis, sort_st + sort_sz / 2); + B_ed = min(size_sorted_axis, sort_st + sort_sz); + } + + int A_sz = A_ed - A_st; + int B_sz = B_ed - B_st; + + // Load from global memory + thread ValT thread_vals[N_PER_THREAD]; + thread IdxT thread_idxs[N_PER_THREAD]; + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + if (idx < (A_sz + B_sz)) { + thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] + : dev_vals_in[B_st + idx - A_sz]; + thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] + : dev_idxs_in[B_st + idx - A_sz]; + } else { + thread_vals[i] = CompareOp::init; + thread_idxs[i] = 0; + } + } + + // Write to shared memory + threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; + threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; i++) { + int idx = BLOCK_THREADS * i + lid.x; + tgp_vals[idx] = thread_vals[i]; + tgp_idxs[idx] = thread_idxs[i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Merge + int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); + + int A_st_local = block_sort_t::merge_partition( + tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); + int A_ed_local = A_sz; + + int B_st_local = sort_md_local - A_st_local; + int B_ed_local = B_sz; + + int A_sz_local = A_ed_local - A_st_local; + int B_sz_local = B_ed_local - B_st_local; + + // Do merge + block_sort_t::merge_step( + tgp_vals + A_st_local, + tgp_vals + A_ed_local + B_st_local, + tgp_idxs + A_st_local, + tgp_idxs + A_ed_local + B_st_local, + A_sz_local, + B_sz_local, + thread_vals, + thread_idxs); + + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N_PER_THREAD; ++i) { + int idx = lid.x * N_PER_THREAD; + tgp_vals[idx + i] = thread_vals[i]; + tgp_idxs[idx + i] = thread_idxs[i]; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + // Write output + int base_idx = tid.x * sort_kernel::N_PER_BLOCK; + for (int i = 
lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { + int idx = base_idx + i; + if (idx < size_sorted_axis) { + dev_vals_out[idx] = tgp_vals[i]; + dev_idxs_out[idx] = tgp_idxs[i]; + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/attn.h b/dist/include/mlx/backend/metal/kernels/steel/attn/attn.h new file mode 100644 index 0000000..991d4d6 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/attn.h @@ -0,0 +1,296 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/attn/loader.h" +#include "mlx/backend/metal/kernels/steel/attn/mma.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? 
BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h b/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h new file mode 100644 index 0000000..4de11b0 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h @@ -0,0 +1,476 @@ +// Copyright © 2024-25 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/steel/attn/attn.h" + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], 
+ uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + // Prepare threadgroup memory + constexpr short padQ = 16 / sizeof(T); + constexpr short padK = 16 / sizeof(T); + constexpr short padV = 16 / sizeof(T); + + constexpr short LDQ_tgp = BD + padQ; + constexpr short LDK_tgp = BK + padK; + constexpr short LDV_tgp = BD + padV; + + constexpr short tgp_mem_0 = (BK + padK) * (BD); + constexpr short tgp_mem_1 = BK * (BD + padV); + constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? 
tgp_mem_0 : tgp_mem_1; + + threadgroup T Q_smem[BQ * (BD + padQ)]; + threadgroup T KV_smem[tgp_mem_s]; + + threadgroup T* Qs = Q_smem; + threadgroup T* Ks = KV_smem; + threadgroup T* Vs = KV_smem; + + // Prepare block loaders + using QBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BQ, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDQ_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 1, + /* short tgp_size = */ WM * WN * 32>; + + // K is loaded in transposed + using KBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ 1, + /* short kDstStrCol = */ LDK_tgp, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + using VBlockLoader = BlockLoaderT< + /* typename T = */ T, + /* short BROWS = */ BK, + /* short BCOLS = */ BD, + /* short kDstStrRow = */ LDV_tgp, + /* short kDstStrCol = */ 1, + /* short reduction_dim = */ 0, + /* short tgp_size = */ WM * WN * 32>; + + QBlockLoader loader_q( + Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); + KBlockLoader loader_k( + K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); + VBlockLoader loader_v( + V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); + + TransformScale ts(static_cast(params->scale * M_LOG2E_F)); + + // Prepare MMA tiles + constexpr short kFragSize = 8; // MMAFrag size + using MMAFrag_acc_t = BaseMMAFrag; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * kFragSize); + // KV sequence frags (all warps load the same frags) + constexpr int TK = BK / kFragSize; + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / kFragSize; + + static_assert(TQ == 1, "Check TQ"); + + MMATile Qtile; + MMATile Ktile; + MMATile Stile; + 
MMATile Vtile; + MMATile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = kFragSize * TQ * simd_group_id; + + const short Qs_offset = (tm + sm) * LDQ_tgp + sn; + const short Ks_offset = sm * LDK_tgp + sn; + const short Vs_offset = sm * LDV_tgp + sn; + + constexpr short Qs_tile_stride = kFragSize; + constexpr short Ks_tile_stride = kFragSize * LDK_tgp; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load Q blocks apply scale + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + loader_q.load_safe(short2(BD, params->qL_rem)); + } else { + loader_q.load_unsafe(); + } + loader_q.apply_inplace_op(ts); + + // Init row reduction variables + constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; + + AccumType max_score[kRowsPT]; + AccumType sum_score[kRowsPT] = {0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + // Load K block and apply scale + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!align_K && kb == (params->NK_aligned)) { + loader_k.load_safe(short2(BD, params->kL_rem)); + } else { + loader_k.load_unsafe(); + } + + // Do S = Q @ K.T + Stile.clear(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short dd = 0; dd < TD; dd++) { + simdgroup_barrier(mem_flags::mem_none); + + Qtile.template load( + &Qs[Qs_offset + dd * Qs_tile_stride]); + Ktile.template load( + 
&Ks[Ks_offset + dd * Ks_tile_stride]); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Stile, Qtile, Ktile, Stile); + } + + // Mask out length sequence + if (!align_K && kb == (params->NK_aligned)) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + short col_pos = sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if ((col_pos + jj) >= params->kL_rem) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = + tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { + if (row_pos < (col_pos + jj)) { + Stile.frag_at(i, j)[jj] = neg_inf; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + using stile_t = decltype(Stile); + using selem_t = typename stile_t::elem_type; + constexpr auto neg_inf = Limits::finite_min; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + + using MMAFrag_mask_t = BaseMMAFrag; + using frag_t = typename MMAFrag_mask_t::frag_type; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < stile_t::kTileRows; i++) { + const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); + STEEL_PRAGMA_UNROLL + for (short j = 0; j < 
stile_t::kTileCols; j++) { + const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); + + frag_t mfrag; + + MMAFrag_mask_t::load_safe( + mfrag, + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + Stile.frag_at(i, j)[jj] = + mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; + } else { + Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); + } + } + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load V blocks + if (!align_K && kb == (params->NK_aligned)) { + loader_v.load_safe(short2(BD, params->kL_rem)); + } else { + loader_v.load_unsafe(); + } + + // Do softmax + + // Temp variables + AccumType new_max[kRowsPT]; + AccumType factor[kRowsPT]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Stile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Stile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + } + + // Save max for next iteration + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = new_max[i]; + } + + // Row Sum + AccumType sum_score_tmp[kRowsPT] = {0}; + Stile.template row_reduce(sum_score_tmp); + + // Update norm + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; + } + + // Update O + Otile.template row_bin_op(factor); + + // Load V into registers + threadgroup_barrier(mem_flags::mem_threadgroup); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); 
+ } + + const short kk = ik * kFragSize; + const short dd = id * kFragSize; + + Vtile.template load( + &Vs[Vs_offset + kk * LDV_tgp + dd]); + + if constexpr (BD == 128) { + simdgroup_barrier(mem_flags::mem_none); + } + + MMAFrag_acc_t::mma( + Otile.frag_at(iq, id), + Stile.frag_at(iq, ik), + Vtile.frag_at(0, 0), + Otile.frag_at(iq, id)); + } + } + } + + // Prepare for next iteration + loader_k.next(); + loader_v.next(); + } + + // Normalize output + Otile.template row_bin_op(sum_score); + threadgroup_barrier(mem_flags::mem_none); + + // Store results + O += (tm + sm) * params->O_strides[2] + sn; + + if (!align_Q && int(tid.x) == (params->NQ_aligned)) { + auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); + } else { + Otile.template store(O, params->O_strides[2]); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h b/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h new file mode 100644 index 0000000..1814f9b --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h @@ -0,0 +1,481 @@ +// Copyright © 2024-25 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/steel/attn/nax.h" +#include "mlx/backend/metal/kernels/steel/attn/params.h" +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool align_Q [[function_constant(200)]]; +constant bool align_K [[function_constant(201)]]; + +constant bool has_mask [[function_constant(300)]]; +constant bool do_causal [[function_constant(301)]]; +constant bool has_sinks [[function_constant(302)]]; + +template +struct TransformScale { + T scale; + METAL_FUNC TransformScale(T scale_) : scale(scale_) {} + + METAL_FUNC T apply(T x) const { + return scale * x; + } +}; + +struct MaxOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return metal::max(x, y); + } +}; + +struct SumOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x + y; + } +}; + +struct MulOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x * y; + } +}; + +struct SubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x - y; + } +}; + +struct ExpSubOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return fast::exp2(x - y); + } +}; + +struct DivOp { + template + METAL_FUNC static constexpr T apply(T x, T y) { + return x / y; + } +}; + +// clang-format off +template < + typename T, + int BQ, + int BK, + int BD, + int WM, + int WN, + typename MaskType = float, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax( + const device T* Q [[buffer(0)]], + const device T* K [[buffer(1)]], + const device T* V [[buffer(2)]], + device T* O [[buffer(3)]], + const constant AttnParams* params [[buffer(4)]], + const constant AttnMaskParams* mask_params [[buffer(5), 
function_constant(has_mask)]], + const device MaskType* mask [[buffer(6), function_constant(has_mask)]], + const device T* sinks [[buffer(7), function_constant(has_sinks)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + + // Pacifying compiler + (void)lid; + (void)simd_lane_id; + + // Move to correct block + ulong3 tidl{tid.x, tid.y, tid.z}; + + Q += tidl.z * params->Q_strides[0] + // Batch + tidl.y * params->Q_strides[1] + // Head + tidl.x * BQ * params->Q_strides[2]; // Sequence + + ulong kv_head_idx = int(tid.y) / params->gqa_factor; + K += tidl.z * params->K_strides[0] + // Batch + kv_head_idx * params->K_strides[1]; // Head + + V += tidl.z * params->V_strides[0] + // Batch + kv_head_idx * params->V_strides[1]; // Head + + O += tidl.z * params->O_strides[0] + // Batch + tidl.y * params->O_strides[1] + // Head + tidl.x * BQ * params->O_strides[2]; // Sequence + + if (has_mask) { + mask += tidl.z * mask_params->M_strides[0] + // Batch + tidl.y * mask_params->M_strides[1]; // Head + } + + const metal::uniform scale2 = + make_uniform(params->scale) * make_uniform(1.44269504089f); + + // Prepare MMA tiles + constexpr short UQ = 16; + constexpr short UD = 32; + + constexpr int kNWarps = WM * WN; + static_assert( + BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, + "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); + + // Q seq frags per warp + constexpr int TQ = BQ / (kNWarps * UQ); + // HeadDim frags (all warps load the same frags) + constexpr int TD = BD / UD; + + static_assert(TQ == 1, "Check TQ"); + + using OSubTile = NAXSubTile; + NAXTile Otile; + + Otile.clear(); + + // Prepare mma tile offsets + const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); + const short sm = simd_coord.y; + const short sn = simd_coord.x; + const short tm = UQ * TQ * 
simd_group_id; + + Q += (tm + sm) * int(params->Q_strides[2]) + sn; + K += sm * int(params->K_strides[2]) + sn; + V += sm * int(params->V_strides[2]) + sn; + + // Init row reduction variables + constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; + + metal::vec max_score; + metal::vec sum_score{0}; + + // Init to -Inf + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = Limits::finite_min; + } + + if (has_sinks) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); + sum_score[i] = 1; + } + } + + int kb_lim = params->NK; + + if (do_causal) { + int q_max = (tid.x + 1) * BQ + params->qL_off; + kb_lim = (q_max + BK - 1) / BK; + kb_lim = min(params->NK, kb_lim); + } + + const bool is_last_bq = int(tid.x) == (params->NQ_aligned); + // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); + const bool is_last_q = is_last_bq; + + const short lim_rows_q = params->qL_rem - (tm + sm); + const short lim_rows_k = params->kL_rem - sm; + + // Loop over KV seq length + for (int kb = 0; kb < kb_lim; kb++) { + const int is_last_k = (kb == (params->NK_aligned)); + + // Do S = Q @ K.T + constexpr short UDs = 16; + constexpr short UKs = 32; + + constexpr short TDs = BD / UDs; + constexpr short TKs = BK / UKs; + + using SSubTile = NAXSubTile; + using QSubTile = NAXSubTile; + using KSubTile = NAXSubTile; + + NAXTile Stile; + + Stile.clear(); + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TKs; ik++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TDs; id++) { + NAXTile Qtile; + NAXTile Ktile; + + const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; + const int K_load_off = + ik * UKs * int(params->K_strides[2]) + id * UDs; + + if (!align_Q && is_last_q) { + // Qtile.load_rows( + // Q + Q_load_off, + // int(params->Q_strides[2]), + // lim_rows_q - iq * UQ); + Qtile.load_safe( + Q + Q_load_off, + 
int(params->Q_strides[2]), + short2(BD, lim_rows_q - iq * UQ)); + } else { + Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); + } + + if (!align_K && is_last_k) { + // Ktile.load_rows( + // K + K_load_off, + // int(params->K_strides[2]), + // lim_rows_k - ik * UKs); + Ktile.load_safe( + K + K_load_off, + int(params->K_strides[2]), + short2(BD, lim_rows_k - ik * UKs)); + } else { + Ktile.load(K + K_load_off, int(params->K_strides[2])); + } + + subtile_matmad_nax( + Stile.subtile_at(iq, ik), + Qtile.subtile_at(0, 0), + metal::false_type{}, + Ktile.subtile_at(0, 0), + metal::true_type{}); + } + } + } + + // Scale S + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Stile.elems()[ii] *= float(scale2); + } + + // Scale and Retile S + constexpr short UK = 16; + constexpr short TK = BK / UK; + using PSubTile = NAXSubTile; + + NAXTile Ptile; + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { + Ptile.elems()[ii] = Stile.elems()[ii]; + } + + // Mask out length sequence + if (!align_K && is_last_k) { + constexpr auto neg_inf = Limits::finite_min; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short col_pos = sn + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = ((col_pos + jj) >= params->kL_rem) ? 
neg_inf : fg[loc]; + } + } + } + } + } + + // Mask out if causal + if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + params->qL_off + tm; + const int base_col = kb * BK; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ; + const short col_pos = base_col + ik * UK; + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { + const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; + const auto c = col_pos + jj + sn; + const auto loc = ii * PSubTile::kFragThrCols + jj; + fg[loc] = (r < c) ? neg_inf : fg[loc]; + } + } + } + } + } + + // Other masking as needed + if (has_mask) { + constexpr auto neg_inf = Limits::finite_min; + + const int base_row = tid.x * BQ + tm; + const int base_col = kb * BK; + + constexpr bool is_bool = is_same_v; + using melem_t = typename metal::conditional_t; + using MSubTile = NAXSubTile; + + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + const short row_pos = base_row + iq * UQ + sm; + const short col_pos = base_col + ik * UK + sn; + + MSubTile mfrag; + mfrag.load_safe( + mask, + int64_t(mask_params->M_strides[2]), + Int<1>{}, + params->qL, + params->kL, + row_pos, + col_pos); + + thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); + + STEEL_PRAGMA_UNROLL + for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { + if constexpr (is_bool) { + fg[jj] = mfrag.elems()[jj] ? 
fg[jj] : neg_inf; + } else { + fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); + } + } + } + } + } + + // Do softmax + + // Temp variables + metal::vec new_max; + metal::vec factor; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + new_max[i] = max_score[i]; + } + + // Row max + Ptile.template row_reduce(new_max); + + // exp(Si - rowmax(Si)) + Ptile.template row_bin_op(new_max); + + // Factor exp(rowmax(Si) - rowmax(Si-1)) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + factor[i] = fast::exp2(max_score[i] - new_max[i]); + max_score[i] = new_max[i]; + } + + // Row Sum + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kRowsPT; ++i) { + sum_score[i] = sum_score[i] * factor[i]; + } + + Ptile.template row_reduce(sum_score); + + // Update O + Otile.template row_bin_op(factor); + + simdgroup_barrier(mem_flags::mem_none); + + // Do O = P @ V + STEEL_PRAGMA_UNROLL + for (short iq = 0; iq < TQ; iq++) { + STEEL_PRAGMA_UNROLL + for (short id = 0; id < TD; id++) { + if constexpr (BD == 128) { + if (id == 2) { + threadgroup_barrier(mem_flags::mem_none); + } + } + + STEEL_PRAGMA_UNROLL + for (short ik = 0; ik < TK; ik++) { + using VSubTile = NAXSubTile; + NAXTile Vtile; + + const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; + + if (!align_K && is_last_k) { + // Vtile.load_rows( + // V + V_load_off, + // int(params->V_strides[2]), + // lim_rows_k - ik * UK); + Vtile.load_safe( + V + V_load_off, + int(params->V_strides[2]), + short2(BD, lim_rows_k - ik * UK)); + } else { + Vtile.load(V + V_load_off, int(params->V_strides[2])); + } + + subtile_matmad_nax( + Otile.subtile_at(iq, id), + Ptile.subtile_at(iq, ik), + metal::bool_constant{}, + Vtile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + + // Prepare for next iteration + K += BK * int(params->K_strides[2]); + V += BK * int(params->V_strides[2]); + } + + // Normalize output + + threadgroup_barrier(mem_flags::mem_none); + + metal::vec rcp; + STEEL_PRAGMA_UNROLL + for 
(short i = 0; i < kRowsPT; ++i) { + rcp[i] = (1.f / sum_score[i]); + } + + Otile.template row_bin_op(rcp); + + // Store results + O += (tm + sm) * int(params->O_strides[2]) + sn; + + if (!align_Q && is_last_q) { + if (lim_rows_q <= 0) + return; + + // Otile.store_rows(O, params->O_strides[2], lim_rows_q); + Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); + } else { + Otile.store(O, int(params->O_strides[2])); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/loader.h b/dist/include/mlx/backend/metal/kernels/steel/attn/loader.h new file mode 100644 index 0000000..7ec7981 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/loader.h @@ -0,0 +1,264 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : 
src_ld(src_ld_), + tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +template +struct CShape { + STEEL_CONST int kRows = R; + STEEL_CONST int kCols = C; +}; + +template < + typename T, + short BROWS, + short BCOLS, + short kDstStrRow, + short kDstStrCol, + short reduction_dim, + short tgp_size, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoaderT { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + /* Constructor */ + METAL_FUNC BlockLoaderT( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = + op.apply(dst[i * kDstStrRow + j * kDstStrCol]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; + } + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? 
i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/mma.h b/dist/include/mlx/backend/metal/kernels/steel/attn/mma.h new file mode 100644 index 0000000..b11a111 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/mma.h @@ -0,0 +1,750 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/attn/transforms.h" +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct Shape2D { + RInt r; + CInt c; + + Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} +}; + +template +struct Layout2D { + Shape shape; + Layout layout; +}; + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + 
typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + typedef metal::vec row_frag_type; + typedef metal::vec col_frag_type; + + template + using dtype_mat_t = typename metal::simdgroup_matrix; + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + src += off_x * str_x + off_y * str_y; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = static_cast(src[0]); + } else { + dst[i * kElemCols + j] = T(0); + } + src += str_y; + } + src -= kElemCols * str_y; + src += str_x; + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + 
typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread dtype_frag_t& A, + thread dtype_frag_t& B, + thread dtype_frag_t& C) { + mat_type D_mat; + dtype_mat_t A_mat; + dtype_mat_t B_mat; + dtype_mat_t C_mat; + + reinterpret_cast&>(A_mat.thread_elements()) = A; + reinterpret_cast&>(B_mat.thread_elements()) = B; + reinterpret_cast&>(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + template + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread dtype_mat_t& A, + thread dtype_mat_t& B, + thread dtype_mat_t& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const frag_type& inp_vals, + thread T* reduced_vals) { + T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread frag_type& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for 
(short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags]; // = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread T 
vals[kRowsPerThread]) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_reduce( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) 
{ + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template < + typename Dtype, + typename Atype, + typename Btype, + typename Ctype, + int M, + int N, + int K, + class MMAFragD, + class MMAFragA, + class MMAFragB, + class MMAFragC> +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short m_serp = m; //(n % 2) ? (M - 1 - m) : m; + short n_serp = (m % 2) ? 
(N - 1 - n) : n; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMAFragD::mma( + D.frag_at(m_serp, n_serp), + A.frag_at(m_serp, k), + B.frag_at(k, n_serp), + C.frag_at(m_serp, n_serp)); + } + } + } +} + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / TM_stride; + // Warp tile size along N + STEEL_CONST short TN = BN / TN_stride; + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, 
ldd); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; 
i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems 
= decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/nax.h b/dist/include/mlx/backend/metal/kernels/steel/attn/nax.h new file mode 100644 index 0000000..c8f3ea5 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/nax.h @@ -0,0 +1,1076 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag 
size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r 
= off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + 
for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + 
typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = short2{0, 0}; // get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_ = BaseNAXFrag> +struct NAXSubTile { + using NAXFrag_t = 
NAXFrag_; + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + 
NAXFrag_t::template row_reduce( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + thread T* vptr = (thread T*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vptr[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + 
typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX 
str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? 
CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + mpp::tensor_ops::matmul2d gemm_op; + + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? 
kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + gemm_op.run(ct_a, ct_b, ct_c); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return 
reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + 
template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void 
tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? 
k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/params.h b/dist/include/mlx/backend/metal/kernels/steel/attn/params.h new file mode 100644 index 0000000..f1cf09f --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/params.h @@ -0,0 +1,44 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// Attn param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct AttnParams { + int B; ///< Batch Size + int H; ///< Heads + int D; ///< Head Dim + + int qL; ///< Query Sequence Length + int kL; ///< Key Sequence Length + + int gqa_factor; ///< Group Query factor + float scale; ///< Attention scale + + int NQ; ///< Number of query blocks + int NK; ///< Number of key/value blocks + + int NQ_aligned; ///< Number of full query blocks + int NK_aligned; ///< Number of full key/value blocks + + int qL_rem; ///< Remainder in last query block + int kL_rem; ///< Remainder in last key/value block + int qL_off; ///< Offset in query sequence start + + int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) + int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) + int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) + int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) +}; + +struct AttnMaskParams { + int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/attn/transforms.h b/dist/include/mlx/backend/metal/kernels/steel/attn/transforms.h new file mode 100644 index 0000000..c0624d2 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/attn/transforms.h @@ 
-0,0 +1,71 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast(x * alpha + (beta * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/conv.h b/dist/include/mlx/backend/metal/kernels/steel/conv/conv.h new file mode 100644 index 0000000..d2e718f --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/conv.h @@ -0,0 +1,13 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/loader.h" +#include "mlx/backend/metal/kernels/steel/conv/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" + +using namespace metal; +using namespace mlx::steel; diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h b/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h new file mode 100644 index 0000000..6f822c1 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h @@ -0,0 +1,176 @@ +// Copyright © 2024 Apple Inc. + +#include + +using namespace metal; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + int N_CHANNELS = 0, + bool SMALL_FILTER = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + using namespace mlx::steel; + + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? 
BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + + using loader_a_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DInputBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_a>, + + // Else go to general loader + typename metal::conditional_t< + // Check if filter size is small enough + SMALL_FILTER, + + // Go to small filter specialization + Conv2DInputBlockLoaderSmallFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>, + + // Else go to large filter generalization + Conv2DInputBlockLoaderLargeFilter< + T, + BM, + BN, + BK, + tgp_size, + tgp_padding_a>>>; + + // Weight loader + using loader_b_t = typename metal::conditional_t< + // Check for small channel specialization + N_CHANNELS != 0 && N_CHANNELS <= 4, + + // Go to small channel specialization + Conv2DWeightBlockLoaderSmallChannels< + T, + BM, + BN, + BK, + tgp_size, + N_CHANNELS, + tgp_padding_b>, + + // Else go to general loader + Conv2DWeightBlockLoader>; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + const int N = gemm_params->N; + const int C_per_group = params->C / params->groups; + + // Groups + A += tid.z * C_per_group; + B += tid.z * N * K; + C += tid.z * N; + + B += c_col * K; + C 
+= c_row * (N * params->groups) + c_col; + + const int2 offsets_a(0, c_row); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); + loader_b_t loader_b( + B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + int gemm_k_iterations = gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + short tgp_bm = min(BM, gemm_params->M - c_row); + short tgp_bn = min(BN, gemm_params->N - c_col); + const int ldc = N * params->groups; + mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h b/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h new file mode 100644 index 0000000..9afebd3 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h @@ -0,0 +1,225 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h" + +constant bool align_C [[function_constant(200)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + typename AccumType = float, + typename Epilogue = TransformNone> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +implicit_gemm_conv_2d_general( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* C [[buffer(2)]], + const constant MLXConvParams<2>* params [[buffer(3)]], + const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], + const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], + const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], + const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + (void)lid; + + constexpr bool transpose_a = false; + constexpr bool transpose_b = true; + constexpr short tgp_padding_a = 16 / sizeof(T); + constexpr short tgp_padding_b = 16 / sizeof(T); + + constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; + constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; + constexpr short shape_a_rows = (transpose_a ? BK : BM); + constexpr short shape_b_rows = (transpose_b ? 
BN : BK); + constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; + constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; + + constexpr short tgp_size = WM * WN * 32; + + // Input loader + using loader_a_t = + Conv2DInputBlockLoaderGeneral; + + // Weight loader + using loader_b_t = + Conv2DWeightBlockLoaderGeneral; + + using mma_t = BlockMMA< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + shape_a_cols, + shape_b_cols>; + + threadgroup T As[tgp_mem_size_a]; + threadgroup T Bs[tgp_mem_size_b]; + + const int tid_y = ((tid.y) << gemm_params->swizzle_log) + + ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> gemm_params->swizzle_log; + + if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { + return; + } + + const int tid_z = tid.z; + + const int base_oh = tid_z / jump_params->f_out_jump_w; + const int base_ow = tid_z % jump_params->f_out_jump_w; + + const int base_wh = base_h[base_oh].weight_base; + const int base_ww = base_w[base_ow].weight_base; + + const int base_wh_size = base_h[base_oh].weight_size; + const int base_ww_size = base_w[base_ow].weight_size; + + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int K = gemm_params->K; + + B += c_col * K; + + const int4 offsets_a(0, c_row, base_oh, base_ow); + const int2 offsets_b(0, c_col); + + // Prepare threadgroup loading operations + loader_a_t loader_a( + A, + As, + offsets_a, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + loader_b_t loader_b( + B, + Bs, + offsets_b, + params, + jump_params, + base_wh, + base_ww, + simd_gid, + simd_lid); + + // Prepare threadgroup mma operation + mma_t mma_op(simd_gid, simd_lid); + + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); 
+ loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Store results to device memory + { + // Adjust for simdgroup and thread location + int offset_m = c_row + mma_op.sm; + int offset_n = c_col + mma_op.sn; + C += offset_n; + + if (offset_n >= gemm_params->N) + return; + + short diff = gemm_params->N - offset_n; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < mma_t::TM; i++) { + int cm = offset_m + i * mma_t::TM_stride; + + int n = cm / jump_params->adj_out_hw; + int hw = cm % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; + + if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { + int offset_cm = n * params->out_strides[0] + + oh * params->out_strides[1] + ow * 
params->out_strides[2]; + + STEEL_PRAGMA_UNROLL + for (int j = 0; j < mma_t::TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = mma_op.Ctile.frag_at(i, j); + int offset = offset_cm + (j * mma_t::TN_stride); + + constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; + + // Apply epilogue and output C + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * mma_t::TN_stride + k) < diff) { + C[offset + k] = Epilogue::apply(accum[k]); + } + } + } + } + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/loader.h b/dist/include/mlx/backend/metal/kernels/steel/conv/loader.h new file mode 100644 index 0000000..f84a640 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/loader.h @@ -0,0 +1,6 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h" +#include "mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h" \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h new file mode 100644 index 0000000..d52642b --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h @@ -0,0 +1,451 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderLargeFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderLargeFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for 
(short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h * params->kdil[0]; + int iw = read_iw[i] + weight_w * params->kdil[1]; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + 
short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallFilter { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + using mask_t = short; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + mask_t mask_h[n_rows]; + mask_t mask_w[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallFilter( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_h(0), + weight_w(0) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + 
read_iw[i] = iw; + + // Adjust for flip + if (params->flip) { + ih += (params->wS[0] - 1) * params->kdil[0]; + iw += (params->wS[1] - 1) * params->kdil[1]; + } + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2] + bj; + } + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + mask_h[i] = 0; + mask_w[i] = 0; + } + + for (short kh = 0; kh < params->wS[0]; kh++) { + short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int n = read_n[i]; + int ih = read_ih[i] + flip_h * params->kdil[0]; + + bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; + + mask_h[i] |= (in_bounds << kh); + } + } + + for (short kw = 0; kw < params->wS[1]; kw++) { + short flip_w = params->flip ? params->wS[1] - kw - 1 : kw; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int iw = read_iw[i] + flip_w * params->kdil[1]; + + bool in_bounds = iw >= 0 && iw < params->iS[1]; + + mask_w[i] |= (in_bounds << kw); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + mask_t h_mask = mask_t(1) << weight_h; + mask_t w_mask = mask_t(1) << weight_w; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Read from input if in bounds + if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = src[i][j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_w < params->wS[1]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_w; + } + + return; + } + + weight_w = 0; + + if (++weight_h < params->wS[0]) { + 
STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_h; + } + + return; + } + + weight_h = 0; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += gemm_params->inp_jump_c; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoader { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + int weight_step; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoader( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + weight_hw(0), + weight_step(params->C / params->groups), + read_n(offsets.y + bi), + do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking 
*/ + METAL_FUNC void load_unsafe() const { + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((read_n + i) < params->O) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + if (++weight_hw < (params->wS[1] * params->wS[0])) { + src += weight_step; + return; + } + + weight_hw = 0; + + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h new file mode 100644 index 0000000..2312e1c --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h @@ -0,0 +1,319 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +#include "mlx/backend/metal/kernels/steel/conv/params.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct ChannelHelper { + STEEL_CONST short n_channels = n_channels_; + STEEL_CONST short vec_size = n_channels_ <= 4 ? 
4 : 8; + STEEL_CONST short excess = vec_size - n_channels_; +}; + +template <> +struct ChannelHelper<1> { + STEEL_CONST short n_channels = 1; + STEEL_CONST short vec_size = 1; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<2> { + STEEL_CONST short n_channels = 2; + STEEL_CONST short vec_size = 2; + STEEL_CONST short excess = 0; +}; + +template <> +struct ChannelHelper<3> { + STEEL_CONST short n_channels = 3; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 1; +}; + +template <> +struct ChannelHelper<4> { + STEEL_CONST short n_channels = 4; + STEEL_CONST short vec_size = 4; + STEEL_CONST short excess = 0; +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant ImplicitGemmConv2DParams* gemm_params; + + int weight_hw; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : 
thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + gemm_params(gemm_params_), + weight_hw(thread_idx % TCOLS) { + int out_n_pixels = params->oS[0] * params->oS[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / out_n_pixels; + int hw = offset_nhw % out_n_pixels; + int oh = hw / params->oS[1]; + int ow = hw % params->oS[1]; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + + iw * params->in_strides[2]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + int wh = (weight_hw / params->wS[1]); + int ww = (weight_hw % params->wS[1]); + + int flip_h = params->flip ? params->wS[0] - wh - 1 : wh; + int flip_w = params->flip ? 
params->wS[1] - ww - 1 : ww; + + int weight_h = flip_h * params->kdil[0]; + int weight_w = flip_w * params->kdil[1]; + + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + int ih = read_ih[i] + weight_h; + int iw = read_iw[i] + weight_w; + + // Read from input if in bounds + if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && + (iw >= 0 && iw < params->iS[1])) { + const device T* curr_src = src[i] + weight_h * params->in_strides[1] + + weight_w * params->in_strides[2]; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; ++j) { + dst[is * dst_ld + j] = curr_src[j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += TCOLS; + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short n_channels, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderSmallChannels { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = ChannelHelper::vec_size; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + + int weight_hw; + + const int read_n; + const bool do_read; + + /* Constructor */ + METAL_FUNC 
Conv2DWeightBlockLoaderSmallChannels( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant ImplicitGemmConv2DParams* gemm_params_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld), + params(params_), + weight_hw(thread_idx % TCOLS), + read_n(offsets.y + bi), + do_read(read_n + BN <= gemm_params_->N) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + if (bi >= BROWS || bj >= BCOLS) + return; + + if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + + return; + } + + const device T* curr_src = src + weight_hw * (params->C / params->groups); + + if (BN != 8 || do_read) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } else { + for (short i = 0; i < BROWS; i += TROWS) { + if (((read_n + i) < params->O)) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < n_channels; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + + STEEL_PRAGMA_UNROLL + for (short j = n_channels; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_hw += 
TCOLS; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h new file mode 100644 index 0000000..9b7ddc2 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h @@ -0,0 +1,381 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DInputBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BM; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4; + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const device T* src[n_rows]; + + int read_n[n_rows]; + int read_ih[n_rows]; + int read_iw[n_rows]; + + /* Constructor */ + METAL_FUNC Conv2DInputBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int4 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; ++i) { + int offset_nhw = offsets.y + bi + i * TROWS; + int n = offset_nhw / jump_params->adj_out_hw; + int hw = offset_nhw % jump_params->adj_out_hw; + int oh = + (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z; + int ow = + (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w; + + int ih = oh * params->str[0] - params->pad[0]; + int iw = ow * params->str[1] - params->pad[1]; + + read_n[i] = n; + read_ih[i] = ih; + read_iw[i] = iw; + + // Read from input if in bounds + src[i] = src_ + n * params->in_strides[0] + bj; + } + } + + /* Load from device memory into threadgroup 
memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? 
params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < n_rows; i++) { + src[i] += BK; + } + } +}; + +template < + typename T, + short BM, + short BN, + short BK, + short tgp_size, + short tgp_padding = 0> +struct Conv2DWeightBlockLoaderGeneral { + // Destination dimensions + STEEL_CONST short BROWS = BN; + STEEL_CONST short BCOLS = BK; + + // Read dimensions + STEEL_CONST short dst_ld = BCOLS + tgp_padding; + STEEL_CONST short vec_size = + (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 
8 : 4); + + // Thread read shape + STEEL_CONST short TCOLS = BCOLS / vec_size; + STEEL_CONST short TROWS = tgp_size / TCOLS; + + // Rows / strided reads within the block + STEEL_CONST short n_rows = BROWS / TROWS; + + // Leading dimension for src + const int src_ld; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + const constant MLXConvParams<2>* params; + const constant Conv2DGeneralJumpParams* jump_params; + + const short base_wh; + const short base_ww; + + short weight_h; + short weight_w; + + const int start_row; + + /* Constructor */ + METAL_FUNC Conv2DWeightBlockLoaderGeneral( + const device T* src_, + threadgroup T* dst_, + const int2 offsets, + const constant MLXConvParams<2>* params_, + const constant Conv2DGeneralJumpParams* jump_params_, + const short base_wh_, + const short base_ww_, + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(params_->wt_strides[0]), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj), + params(params_), + jump_params(jump_params_), + base_wh(base_wh_), + base_ww(base_ww_), + weight_h(base_wh_), + weight_w(base_ww_), + start_row(offsets.y + bi) {} + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + STEEL_PRAGMA_UNROLL + for 
(short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + weight_w += jump_params->f_wgt_jump_w; + if (weight_w < params->wS[1]) { + return; + } + + weight_w = base_ww; + + weight_h += jump_params->f_wgt_jump_h; + if (weight_h < params->wS[0]) { + return; + } + + weight_h = base_wh; + + src += BK; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/conv/params.h b/dist/include/mlx/backend/metal/kernels/steel/conv/params.h new file mode 100644 index 0000000..61b8474 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/conv/params.h @@ -0,0 
+1,62 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +template +struct MLXConvParams { + const int N; // Batch size + const int C; // In channels + const int O; // Out channels + const int iS[NDIM]; // Input spatial dim + const int wS[NDIM]; // Weight spatial dim + const int oS[NDIM]; // Output spatial dim + const int str[NDIM]; // Kernel strides + const int pad[NDIM]; // Input padding + const int kdil[NDIM]; // Kernel dilation + const int idil[NDIM]; // Input dilation + const int64_t in_strides[NDIM + 2]; // In strides + const int64_t wt_strides[NDIM + 2]; // Wt strides + const int64_t out_strides[NDIM + 2]; // Out strides + const int groups; // Input channel groups + const bool flip; +}; + +namespace mlx { +namespace steel { + +struct ImplicitGemmConv2DParams { + const int M; + const int N; + const int K; + + const int gemm_k_iterations; + + const int inp_jump_w; + const int inp_jump_h; + const int inp_jump_c; + + const int tiles_n; + const int tiles_m; + const int swizzle_log; +}; + +struct Conv2DGeneralJumpParams { + const int f_wgt_jump_h; + const int f_wgt_jump_w; + + const int f_out_jump_h; + const int f_out_jump_w; + + const int adj_out_h; + const int adj_out_w; + const int adj_out_hw; + const int adj_implicit_m; +}; + +struct Conv2DGeneralBaseInfo { + int weight_base; + int weight_size; +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/defines.h b/dist/include/mlx/backend/metal/kernels/steel/defines.h new file mode 100644 index 0000000..f5657ee --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/defines.h @@ -0,0 +1,7 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#define STEEL_CONST static constant constexpr const +#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") +#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm.h new file mode 100644 index 0000000..bbe1d96 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm.h @@ -0,0 +1,295 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/loader.h" +#include "mlx/backend/metal/kernels/steel/gemm/mma.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernel class +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct LoopAlignment {}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + typename AccumType = typename AccumHelper::accum_type, + typename Epilogue = TransformNone> +struct GEMMKernel { + STEEL_CONST short tgp_padding_a = 16 / sizeof(T); + STEEL_CONST short tgp_padding_b = 16 / sizeof(T); + STEEL_CONST short tgp_mem_size_a = + transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); + STEEL_CONST short tgp_mem_size_b = + transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); + STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; + + STEEL_CONST short tgp_size = WM * WN * 32; + + using loader_a_t = BlockLoader< + T, + transpose_a ? BK : BM, + transpose_a ? BM : BK, + transpose_a ? 
BM + tgp_padding_a : BK + tgp_padding_a, + !transpose_a, + tgp_size>; + using loader_b_t = BlockLoader< + T, + transpose_b ? BN : BK, + transpose_b ? BK : BN, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + transpose_b, + tgp_size>; + using mma_t = BlockMMA< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, + transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, + AccumType, + Epilogue>; + + /* Main kernel function */ + template + static METAL_FUNC void gemm_loop( + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + const int gemm_k_iterations, + thread loader_a_t& loader_a, + thread loader_b_t& loader_b, + thread mma_t& mma_op, + thread const short& tgp_bm, + thread const short& tgp_bn, + thread const short& lbk, + LoopAlignment l = {}) { + // Appease the compiler + (void)l; + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + + short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned_) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? 
short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + /* Main kernel function */ + static METAL_FUNC void run( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* D [[buffer(2)]], + const constant GEMMParams* params [[buffer(3)]], + threadgroup T* As [[threadgroup(0)]], + threadgroup T* Bs [[threadgroup(1)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Pacifying compiler + (void)lid; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? 
short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; + + if (tgp_bm == BM && tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result(D, params->ldd); + return; + + } else if (tgp_bn == BN) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else if (tgp_bm == BM) { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + + } else { + gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk); + + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + return; + } + } + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h new file mode 100644 index 0000000..04d3b6a --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h @@ -0,0 +1,156 @@ +// Copyright © 2025 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/gemm/nax.h" +#include "mlx/backend/metal/kernels/steel/gemm/params.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils.h" + +using namespace metal; + +namespace mlx::steel { + +template < + typename T, + short SM, + short SN, + short SK, + short BK, + bool transpose_a, + bool transpose_b, + bool kAlignedM, + bool kAlignedN, + bool kAlignedK, + short UM, + short UN, + short UK, + typename AccumType = float> +auto gemm_loop( + const device T* A, + const device T* B, + const constant GEMMParams* params [[buffer(4)]], + const short sgp_sm, + const short sgp_sn) { + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + constexpr short TK = SK / UK; + + constexpr int RA = transpose_a ? TK : TM; + constexpr int CA = transpose_a ? TM : TK; + + constexpr int RB = transpose_b ? TN : TK; + constexpr int CB = transpose_b ? TK : TN; + + using DSubTile = NAXSubTile; + using ASubTile = + NAXSubTile; + using BSubTile = + NAXSubTile; + + NAXTile Dtile; + Dtile.clear(); + + int gemm_k_iterations_ = params->gemm_k_iterations_aligned; + + STEEL_PRAGMA_NO_UNROLL + for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) { + threadgroup_barrier(mem_flags::mem_none); + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < BK; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + const int k = kk1; + + volatile int compiler_barrier; + + const int A_offset = transpose_a ? k * params->lda : k; + const int B_offset = transpose_b ? k : k * params->ldb; + + if constexpr (kAlignedM) { + Atile.load(A + A_offset, params->lda); + } else { + const short rmax = transpose_a ? SK : sgp_sm; + const short cmax = transpose_a ? sgp_sm : SK; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + if constexpr (kAlignedN) { + Btile.load(B + B_offset, params->ldb); + } else { + const short rmax = transpose_b ? sgp_sn : SK; + const short cmax = transpose_b ? 
SK : sgp_sn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + tile_matmad_nax( + Dtile, + Atile, + metal::bool_constant{}, + Btile, + metal::bool_constant{}); + + (void)compiler_barrier; + } + + A += transpose_a ? (BK * params->lda) : BK; + B += transpose_b ? BK : (BK * params->ldb); + } + + if constexpr (!kAlignedK) { + simdgroup_barrier(mem_flags::mem_none); + + const short rem_bk = params->K - gemm_k_iterations_ * BK; + + STEEL_PRAGMA_NO_UNROLL + for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { + NAXTile Atile; + NAXTile Btile; + + STEEL_PRAGMA_UNROLL + for (int mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (int nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (int kk = 0; kk < TK; kk++) { + const int m = mm * UM; + const int n = nn * UN; + const int k = kk1 + kk * UK; + const short psk = max(0, rem_bk - k); + + const int A_offset = + transpose_a ? (m + k * params->lda) : (m * params->lda + k); + const int B_offset = + transpose_b ? (k + n * params->ldb) : (k * params->ldb + n); + + { + const short psm = kAlignedM ? SM : max(0, sgp_sm - m); + const short rmax = transpose_a ? psk : psm; + const short cmax = transpose_a ? psm : psk; + Atile.load_safe(A + A_offset, params->lda, short2(cmax, rmax)); + } + + { + const short psn = kAlignedN ? SN : max(0, sgp_sn - n); + const short rmax = transpose_b ? psn : psk; + const short cmax = transpose_b ? 
psk : psn; + Btile.load_safe(B + B_offset, params->ldb, short2(cmax, rmax)); + } + + subtile_matmad_nax( + Dtile.subtile_at(mm, nn), + Atile.subtile_at(0, 0), + metal::bool_constant{}, + Btile.subtile_at(0, 0), + metal::bool_constant{}); + } + } + } + } + } + + return Dtile; +} + +} // namespace mlx::steel diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h new file mode 100644 index 0000000..8583087 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h @@ -0,0 +1,346 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id 
[[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on + // Pacifying compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + const TransformAdd epilogue_op_add( + addmm_params->alpha, addmm_params->beta); + const TransformAxpby epilogue_op_axpby( + addmm_params->alpha, addmm_params->beta); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (align_M && align_N) { + // Do gemm + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + const int leftover_bk = 0; + + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + // Do gemm + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue( + C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); + } else { + mma_op.apply_epilogue( + C, 
addmm_params->ldc, addmm_params->fdc, epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result(D, params->ldd); + + } else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + + // Do epilogue + if (use_out_source) { + if (do_axpby) { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_axpby); + } else { + mma_op.apply_epilogue_safe( + C, + addmm_params->ldc, + addmm_params->fdc, + short2(tgp_bn, tgp_bm), + epilogue_op_add); + } + } + + // Store results to device memory + return mma_op.store_result_safe(D, params->ldd, 
short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h new file mode 100644 index 0000000..44328ed --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h @@ -0,0 +1,207 @@ +// Copyright © 2025 Apple Inc. + +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; + +constant bool use_out_source [[function_constant(100)]]; +constant bool do_axpby [[function_constant(110)]]; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +// clang-format off +template < + bool kAlignedM, + bool kAlignedN, + typename NAXTile_t, + typename T> +void gemm_epilogue( + thread NAXTile_t& Dtile, + const device T* C, + const constant GEMMParams* params, + const constant GEMMAddMMParams* addmm_params, + const short sgp_sm, + const short sgp_sn) { // clang-format on + + (void)params; + + constexpr short UM = NAXTile_t::kSubTileRows; + constexpr short UN = NAXTile_t::kSubTileCols; + using CSubTile = NAXSubTile; + + using V = typename NAXTile_t::elem_type; + + constexpr short TM = NAXTile_t::kTileRows; + constexpr short TN = NAXTile_t::kTileCols; + constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; + + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + const short m = mm * UM; + const short n = nn * UN; + + CSubTile CTile; + + if constexpr (kAlignedM && kAlignedN) { + CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); + } else { + CTile.load_safe( + C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); + } + + auto delems = Dtile.subtile_at(mm, nn).elems(); + auto celems = CTile.elems(); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemsPerSubTile; i++) { + if 
(do_axpby) { + delems[i] = addmm_params->alpha * delems[i] + + addmm_params->beta * static_cast(celems[i]); + } else { + delems[i] += static_cast(celems[i]); + } + } + } + } +} + +// clang-format off +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device T* C [[buffer(2), function_constant(use_out_source)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on + // Find block + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + // Exit early if out of bounds + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Adjust for batch + if (has_batch) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + if (use_out_source) { + const constant auto* C_bstrides = B_bstrides + params->batch_ndim; + C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); + } + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + + if (use_out_source) { + C += addmm_params->batch_stride_c * tid.z; + } + } + + D += 
params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + if (use_out_source) { + C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; + } + + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? 
(tn * params->ldb) : tn; + D += tm * params->ldd + tn; + + if (use_out_source) { + C += tm * addmm_params->ldc + tn * addmm_params->fdc; + } + + using DSubTile = NAXSubTile; + NAXTile Dtile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + Dtile = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>(A, B, params, sgp_sm, sgp_sn); + if (use_out_source) { + gemm_epilogue( + Dtile, C, params, addmm_params, sgp_sm, sgp_sn); + } + if constexpr (kAlignedM && kAlignedN) { + Dtile.store(D, int(params->ldd)); + } else { + Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm)); + } + }); + }); + }); +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h new file mode 100644 index 0000000..4493375 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h @@ -0,0 +1,459 @@ +// Copyright © 2024 Apple Inc. 
+ +using namespace mlx::steel; + +constant bool has_batch [[function_constant(10)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm_rhs( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[c_row]; + short offset_next = 0; + int n = 0; + while (n < tgp_bm) { + n++; + offset = offset_next; + index = index_next; + offset_next = tgp_bm; + for (; n < tgp_bm; n++) { + if (rhs_indices[c_row + n] != index) { + offset_next = n; + index_next = rhs_indices[c_row + n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b( + B + index * params->batch_stride_b, + params->ldb, + Bs, + simd_group_id, + simd_lane_id); + + // Prepare iterations + const int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + if (offset_next - offset == BM) { + mma_op.store_result(C, params->ldd); + } else { + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(BN, offset_next)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + 
mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_slice( + C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); + } + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gather_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* lhs_indices [[buffer(2)]], + const device uint32_t* rhs_indices [[buffer(3)]], + device T* C [[buffer(4)]], + const constant GEMMParams* params [[buffer(5)]], + const constant int* indices_shape [[buffer(6)]], + const constant int64_t* lhs_strides [[buffer(7)]], + const constant int64_t* rhs_strides [[buffer(8)]], + const constant int& batch_ndim_a [[buffer(9)]], + const constant int* batch_shape_a [[buffer(10)]], + const constant int64_t* batch_strides_a [[buffer(11)]], + const constant int& batch_ndim_b [[buffer(12)]], + const constant int* batch_shape_b [[buffer(13)]], + const constant int64_t* batch_strides_b [[buffer(14)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Move A and B to the locations pointed by lhs_indices and rhs_indices. 
+ uint32_t indx_A, indx_B; + if (has_batch) { + ulong2 indices_offsets = elem_to_loc_broadcast( + tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); + indx_A = lhs_indices[indices_offsets.x]; + indx_B = rhs_indices[indices_offsets.y]; + } else { + indx_A = lhs_indices[params->batch_stride_a * tid.z]; + indx_B = rhs_indices[params->batch_stride_b * tid.z]; + } + A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); + B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); + C += params->batch_stride_d * tid.z; + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Just make sure everybody's finished with the indexing math above. + threadgroup_barrier(mem_flags::mem_none); + + // Find block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Prepare iterations + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + // Do unaligned K iterations first + if (!align_K) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? 
params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + // Move loader source ahead to end + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + + // Matrix level aligned never check + if (align_M && align_N) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + // Store results to device memory + mma_op.store_result(C, params->ldd); + } else { + const short lbk = 0; + + // Tile aligned don't check + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + 
loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + lbk, + LoopAlignment{}); + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h new file mode 100644 index 0000000..2928583 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h @@ -0,0 +1,132 @@ +// Copyright © 2024 Apple Inc. + +using namespace mlx::steel; + +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; +constant bool align_K [[function_constant(202)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +gather_mm_rhs_nax( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* rhs_indices [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + constexpr short UM = 16; + constexpr short UN = 32; + constexpr short UK = 16; + constexpr short SM = BM / WM; + constexpr short SN = BN / WN; + constexpr short SK = 32; + constexpr short TM = SM / UM; + constexpr short TN = SN / UN; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = 
size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + rhs_indices += c_row; + + const short tm = SM * (simd_group_id / WN); + const short tn = SN * (simd_group_id % WN); + + const short sgp_sm = align_M ? SM : min(SM, short(params->M - (c_row + tm))); + const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); + + const short sgp_sn = align_N ? SN : min(SN, short(params->N - (c_col + tn))); + const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); + + A += transpose_a ? tm : (tm * params->lda); + B += transpose_b ? (tn * params->ldb) : tn; + C += tm * params->ldd + tn; + rhs_indices += tm; + + // Do as many matmuls as necessary + uint32_t index; + short offset; + uint32_t index_next = rhs_indices[0]; + short offset_next = 0; + int n = 0; + while (n < sgp_sm) { + n++; + offset = offset_next; + index = index_next; + offset_next = sgp_sm; + for (; n < sgp_sm; n++) { + if (rhs_indices[n] != index) { + offset_next = n; + index_next = rhs_indices[n]; + break; + } + } + threadgroup_barrier(mem_flags::mem_none); + + using DSubTile = NAXSubTile; + NAXTile Ctile; + + dispatch_bool(align_K, [&](auto kAlignedK) { + dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { + dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { + auto do_gemm = gemm_loop< + T, + SM, + SN, + SK, + BK, + transpose_a, + transpose_b, + kAlignedM.value, + kAlignedN.value, + kAlignedK.value, + UM, + UN, + UK, + AccumType>; + Ctile = do_gemm( + A, B + index * params->batch_stride_b, params, sgp_sm, sgp_sn); + + if constexpr (kAlignedN.value) { + if (offset_next - offset == SM) { + Ctile.store(C, int(params->ldd)); + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + short2(SN, offset_next)); + } + } else { + Ctile.store_slice( + C, + int(params->ldd), + short2(0, offset), + 
short2(sgp_sn, offset_next)); + } + }); + }); + }); + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h new file mode 100644 index 0000000..c8ffe2b --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h @@ -0,0 +1,719 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/backend/metal/kernels/steel/defines.h" +using namespace metal; +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +struct _NoMask { + char x; + + constexpr METAL_FUNC operator bool() { + return true; + } + constexpr METAL_FUNC operator bool() const threadgroup { + return true; + } + constexpr METAL_FUNC operator bool() const device { + return true; + } + constexpr METAL_FUNC operator bool() const constant { + return true; + } +}; + +template +struct ScaleOp { + OutT scale; + + METAL_FUNC OutT apply(InT x) const { + return static_cast(x) * scale; + } +}; + +typedef struct _NoMask nomask_t; + +template < + typename T, + typename out_mask_t, + typename op_mask_t, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant int64_t* batch_strides [[buffer(7)]], + const device out_mask_t* out_mask [[buffer(10)]], + const device op_mask_t* lhs_mask [[buffer(11)]], + const device op_mask_t* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint 
simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + static_assert( + BM == BN, + "block_masked_gemm must have the same block M and block N size"); + static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); + + constexpr bool has_operand_mask = !metal::is_same_v; + constexpr bool has_output_mask = !metal::is_same_v; + + constexpr bool has_mul_operand_mask = + has_operand_mask && !metal::is_same_v; + constexpr bool has_mul_output_mask = + has_output_mask && !metal::is_same_v; + + constexpr short k_mask_factor = short(BM / BK); + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + const constant auto* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + + if (params->batch_ndim > 1) { + if (has_output_mask) { + out_mask += elem_to_loc( + tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + const constant auto* mask_strides_lhs = mask_batch_strides; + const constant auto* mask_strides_rhs = + mask_strides_lhs + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + if (has_output_mask) { + out_mask += tid.z * mask_batch_strides[0]; + mask_batch_strides += params->batch_ndim; + } + + if (has_operand_mask) { + lhs_mask += tid.z * mask_batch_strides[0]; + rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; + } + } + + // 
Adjust for batch + if (params->batch_ndim > 1) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + const constant int* out_mask_strides = mask_strides; + const constant int* lhs_mask_strides = + mask_strides + (has_output_mask ? 2 : 0); + const constant int* rhs_mask_strides = + lhs_mask_strides + (has_operand_mask ? 2 : 0); + + const int out_mask_offset = !has_output_mask + ? 0 + : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; + int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; + int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; + const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; + const int rhs_mask_step = !has_operand_mask ? 
0 : rhs_mask_strides[1]; + short k_factor_cnt = k_mask_factor; + + ScaleOp out_mask_op; + ScaleOp lhs_mask_op; + ScaleOp rhs_mask_op; + + if (has_output_mask) { + auto mask_out = out_mask[out_mask_offset]; + + if (has_mul_output_mask) { + out_mask_op.scale = float(mask_out); + } + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup bounds + const short tgp_bm = + MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); + const short tgp_bn = + MN_aligned ? 
short(BN) : short(min(BN, params->N - c_col)); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + /////////////////////////////////////////////////////////////////////////////// + // Do unaligned K iterations first + if (!K_aligned) { + const int k_last = params->gemm_k_iterations_aligned * BK; + const int mask_idx_last = k_last / BM; + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && + bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = + lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; + rhs_mask_op.scale = + rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; + } + + // Move loader source ahead to end + const int k_remain = params->K - k_last; + const size_t k_jump_a = + transpose_a ? params->lda * size_t(k_last) : size_t(k_last); + const size_t k_jump_b = + transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); + + loader_a.src += k_jump_a; + loader_b.src += k_jump_b; + + // Load tile + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Do matmul + mma_op.mma(As, Bs); + + // Reset source back to start + loader_a.src -= k_jump_a; + loader_b.src -= k_jump_b; + } + } + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { + const bool M_aligned = (tgp_bm == BM); + const bool N_aligned = (tgp_bn == BN); + + const short2 tile_dims_A = + transpose_a ? 
short2(tgp_bm, BK) : short2(BK, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (; gemm_k_iterations > 0; gemm_k_iterations--) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (bool(lhs_mask[lhs_mask_offset]) && + bool(rhs_mask[rhs_mask_offset]))) { + if (has_mul_operand_mask) { + lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; + rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; + } + + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + if (has_mul_operand_mask) { + loader_a.apply_inplace_op(lhs_mask_op); + loader_b.apply_inplace_op(rhs_mask_op); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + + k_factor_cnt--; + lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; + rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; + k_factor_cnt = k_factor_cnt == 0 ? 
k_mask_factor : k_factor_cnt; + } + + if (has_mul_output_mask) { + mma_op.apply_epilogue(out_mask_op); + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned, + bool has_operand_mask = false> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void +block_masked_gemm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device T* D [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + const constant int* batch_shape [[buffer(6)]], + const constant int64_t* batch_strides [[buffer(7)]], + const device bool* out_mask [[buffer(10)]], + const device bool* lhs_mask [[buffer(11)]], + const device bool* rhs_mask [[buffer(12)]], + const constant int* mask_strides [[buffer(13)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + // Appease the compiler + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + + const int tid_y = ((tid.y) << params->swizzle_log) + + ((tid.x) & ((1 << params->swizzle_log) - 1)); + const int tid_x = (tid.x) >> params->swizzle_log; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + if (params->batch_ndim > 1) { + const constant auto* mask_batch_strides = + batch_strides + 2 * params->batch_ndim; + out_mask += + elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); + + if (has_operand_mask) { + const constant auto* mask_strides_lhs = + mask_batch_strides + params->batch_ndim; + const constant auto* mask_strides_rhs = + mask_strides_lhs + 
params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, + batch_shape, + mask_strides_lhs, + mask_strides_rhs, + params->batch_ndim); + + lhs_mask += batch_offsets.x; + rhs_mask += batch_offsets.y; + } + } else { + out_mask += tid.z * batch_strides[2 * params->batch_ndim]; + if (has_operand_mask) { + lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; + rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; + } + } + + // Adjust for batch + if (params->batch_ndim > 1) { + const constant auto* A_bstrides = batch_strides; + const constant auto* B_bstrides = batch_strides + params->batch_ndim; + + ulong2 batch_offsets = elem_to_loc_broadcast( + tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); + + A += batch_offsets.x; + B += batch_offsets.y; + + } else { + A += params->batch_stride_a * tid.z; + B += params->batch_stride_b * tid.z; + } + + D += params->batch_stride_d * tid.z; + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + D += c_row_long * params->ldd + c_col_long; + + bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; + + // Write zeros and return + if (!mask_out) { + constexpr short tgp_size = WM * WN * 32; + constexpr short vec_size = 4; + + // Tile threads in threadgroup + constexpr short TN = BN / vec_size; + constexpr short TM = tgp_size / TN; + + const short thread_idx = simd_group_id * 32 + simd_lane_id; + const short bi = thread_idx / TN; + const short bj = vec_size * (thread_idx % TN); + + D += bi * params->ldd + bj; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + for (short ti = 0; ti < BM; ti += TM) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } else { + short jmax = tgp_bn - bj; + jmax = jmax < vec_size ? jmax : vec_size; + for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { + for (short j = 0; j < jmax; j++) { + D[ti * params->ldd + j] = T(0.); + } + } + } + + return; + } + + threadgroup_barrier(mem_flags::mem_none); + + // Prepare threadgroup mma operation + thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Prepare threadgroup loading operations + thread typename gemm_kernel::loader_a_t loader_a( + A, params->lda, As, simd_group_id, simd_lane_id); + thread typename gemm_kernel::loader_b_t loader_b( + B, params->ldb, Bs, simd_group_id, simd_lane_id); + + /////////////////////////////////////////////////////////////////////////////// + // MNK aligned loop + if (MN_aligned) { + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / 
BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + + // Loop tail + if (!K_aligned) { + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + int lbk = params->K - params->gemm_k_iterations_aligned * BK; + short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); + short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); + + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + // Store results to device memory + mma_op.store_result(D, params->ldd); + return; + + } + /////////////////////////////////////////////////////////////////////////////// + // MN unaligned loop + else { // Loop over K - unaligned case + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short lbk = params->K - params->gemm_k_iterations_aligned * BK; + + bool M_aligned = (tgp_bm == BM); + bool N_aligned = (tgp_bn == BN); + + short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); + short2 tile_dims_B = transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK); + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && + rhs_mask + [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { + // Load elements into threadgroup + if (M_aligned) { + loader_a.load_unsafe(); + } else { + loader_a.load_safe(tile_dims_A); + } + + if (N_aligned) { + loader_b.load_unsafe(); + } else { + loader_b.load_safe(tile_dims_B); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + } + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + + if (!K_aligned) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (!has_operand_mask || + (lhs_mask + [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && + rhs_mask + [(params->K / BM) * mask_strides[5] + + tid_x * mask_strides[4]])) { + short2 tile_dims_A_last = + transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); + short2 tile_dims_B_last = + transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); + + loader_a.load_safe(tile_dims_A_last); + loader_b.load_safe(tile_dims_B_last); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + mma_op.mma(As, Bs); + } + } + + if (M_aligned && N_aligned) { + mma_op.store_result(D, params->ldd); + } else { + mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h new file mode 100644 index 0000000..b915eb3 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h @@ -0,0 +1,266 @@ +// Copyright © 2025 Apple Inc. 
+ +using namespace mlx::steel; + +constant bool segments_contiguous [[function_constant(199)]]; +constant bool align_M [[function_constant(200)]]; +constant bool align_N [[function_constant(201)]]; + +template < + typename T, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + typename AccumType = float> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void segmented_mm( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + const device uint32_t* segments [[buffer(2)]], + device T* C [[buffer(3)]], + const constant GEMMParams* params [[buffer(4)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]]) { + using gemm_kernel = GEMMKernel< + T, + T, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + true, + true, + AccumType>; + + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + if (params->tiles_n <= static_cast(tid.x) || + params->tiles_m <= static_cast(tid.y)) { + return; + } + + // Prepare threadgroup memory + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + // Find the block in A, B, C + const int c_row = tid.y * BM; + const int c_col = tid.x * BN; + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + + // Prepare threadgroup bounds + const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); + const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); + + // Move the pointers to the output tile + A += transpose_a ? c_row_long : c_row_long * params->lda; + B += transpose_b ? 
c_col_long * params->ldb : c_col_long; + C += c_row_long * params->ldd + c_col_long; + + // Move the pointers to the start of the segment + uint32_t k_start, k_end; + if (segments_contiguous) { + k_start = segments[2 * tid.z]; + k_end = segments[2 * tid.z + 1]; + } else { + // We accept either contiguous (above) or weird strides where the beginning + // of the next one is the previous one. Basically the last two strides are + // both 1! + k_start = segments[tid.z]; + k_end = segments[tid.z + 1]; + } + A += transpose_a ? k_start * params->lda : k_start; + B += transpose_b ? k_start : k_start * params->ldb; + C += tid.z * params->batch_stride_d; + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Matrix level alignment so only check K + if (align_M && align_N) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? 
short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } else { + // Tile aligned do the same as above + if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result(C, params->ldd); + } + + // Tile partially aligned check rows + else if (align_N || tgp_bn == BN) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? 
short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Tile partially aligned check cols + else if (align_M || tgp_bm == BM) { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_safe( + transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + + // Nothing aligned so check both rows and cols + else { + uint32_t k = k_start + BK; + for (; k <= k_end; k += BK) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Load elements into threadgroup + loader_a.load_safe( + transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); + loader_b.load_safe( + transpose_b ? 
short2(BK, tgp_bn) : short2(tgp_bn, BK)); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + short k_remain = BK - short(k - k_end); + const short2 tile_dims_A = + transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); + const short2 tile_dims_B = + transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); + if (k_remain > 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(tile_dims_A); + loader_b.load_safe(tile_dims_B); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + } + mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); + } + } +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h new file mode 100644 index 0000000..1ff97ea --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h @@ -0,0 +1,227 @@ +// Copyright © 2024 Apple Inc. 
+ +using namespace mlx::steel; + +/////////////////////////////////////////////////////////////////////////////// +// GEMM kernels +/////////////////////////////////////////////////////////////////////////////// + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + bool MN_aligned, + bool K_aligned> +[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk( + const device T* A [[buffer(0)]], + const device T* B [[buffer(1)]], + device U* C [[buffer(2)]], + const constant GEMMSpiltKParams* params [[buffer(3)]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]], + uint3 tid [[threadgroup_position_in_grid]], + uint3 lid [[thread_position_in_threadgroup]]) { + (void)lid; + + using gemm_kernel = GEMMKernel< + T, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + MN_aligned, + K_aligned>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + + const int tid_x = tid.x; + const int tid_y = tid.y; + const int tid_z = tid.z; + + if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { + return; + } + + // Find block in A, B, C + const int c_row = tid_y * BM; + const int c_col = tid_x * BN; + const int k_start = params->split_k_partition_size * tid_z; + + const size_t c_row_long = size_t(c_row); + const size_t c_col_long = size_t(c_col); + const size_t k_start_long = size_t(k_start); + + A += transpose_a ? (c_row_long + k_start_long * params->lda) + : (k_start_long + c_row_long * params->lda); + B += transpose_b ? 
(k_start_long + c_col_long * params->ldb) + : (c_col_long + k_start_long * params->ldb); + C += (size_t(params->split_k_partition_stride) * tid_z) + + (c_row_long * params->ldc + c_col_long); + + // Prepare threadgroup loading operations + thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); + thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); + + // Prepare threadgroup mma operation + thread mma_t mma_op(simd_group_id, simd_lane_id); + + int gemm_k_iterations = params->gemm_k_iterations_aligned; + + short tgp_bm = min(BM, params->M - c_row); + short tgp_bn = min(BN, params->N - c_col); + short leftover_bk = params->K % BK; + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bn == BN) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else if (tgp_bm == BM) { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } else { + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iterations, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if ((tid_z + 1) == (params->split_k_partitions)) { + int gemm_k_iter_remaining = + (params->K - (k_start + params->split_k_partition_size)) / BK; + if (!K_aligned || gemm_k_iter_remaining > 0) + gemm_kernel::gemm_loop( + As, + Bs, + gemm_k_iter_remaining, + loader_a, + loader_b, + mma_op, + tgp_bm, + tgp_bn, + leftover_bk, + LoopAlignment{}); + } + + if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { + mma_op.store_result(C, params->ldc); + } else { + mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); + } +} + 
+/////////////////////////////////////////////////////////////////////////////// +// Split k accumulation kernel +/////////////////////////////////////////////////////////////////////////////// + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformNone> +[[kernel]] void gemm_splitk_accum( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + uint2 gid [[thread_position_in_grid]]) { + // Adjust D and C + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + D[0] = Epilogue::apply(out); +} + +template < + typename AccT, + typename OutT, + typename Epilogue = TransformAxpby> +[[kernel]] void gemm_splitk_accum_axpby( + const device AccT* C_split [[buffer(0)]], + device OutT* D [[buffer(1)]], + const constant int& k_partitions [[buffer(2)]], + const constant int& partition_stride [[buffer(3)]], + const constant int& ldd [[buffer(4)]], + const device OutT* C [[buffer(5)]], + const constant int& ldc [[buffer(6)]], + const constant int& fdc [[buffer(7)]], + const constant float& alpha [[buffer(8)]], + const constant float& beta [[buffer(9)]], + uint2 gid [[thread_position_in_grid]]) { + // Adjust D and C + C += gid.x * size_t(fdc) + gid.y * size_t(ldc); + D += gid.x + gid.y * size_t(ldd); + C_split += gid.x + gid.y * size_t(ldd); + + size_t offset = 0; + AccT out = 0; + + for (int i = 0; i < k_partitions; i++) { + out += C_split[offset]; + offset += partition_stride; + } + + // Write output + Epilogue op(alpha, beta); + D[0] = op.apply(out, *C); +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/loader.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/loader.h new file mode 100644 
index 0000000..d421b2d --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/loader.h @@ -0,0 +1,137 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/kernels/steel/defines.h" + +/////////////////////////////////////////////////////////////////////////////// +// Loading helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template < + typename T, + short BROWS, + short BCOLS, + short dst_ld, + short reduction_dim, + short tgp_size, + short alignment = 1, + short n_reads = (BCOLS * BROWS) / (tgp_size), + short TCOLS = BCOLS / n_reads, + short TROWS = tgp_size / TCOLS> +struct BlockLoader { + STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; + STEEL_CONST short vec_size = n_reads; + + // Leading dimension for src + const int src_ld; + const int tile_stride; + + // Thread location indices + const short thread_idx; + const short bi; + const short bj; + + // threadgroup and device memory + threadgroup T* dst; + const device T* src; + + struct alignas(alignment * sizeof(T)) ReadVector { + uint8_t v[sizeof(T) * vec_size]; + }; + + /* Constructor */ + METAL_FUNC BlockLoader( + const device T* src_, + const int src_ld_, + threadgroup T* dst_, + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) + : src_ld(src_ld_), + tile_stride(reduction_dim ? 
BCOLS : BROWS * src_ld), + thread_idx(simd_group_id * 32 + simd_lane_id), + bi(thread_idx / TCOLS), + bj(vec_size * (thread_idx % TCOLS)), + dst(dst_ + bi * dst_ld + bj), + src(src_ + bi * src_ld + bj) {} + + /* Apply operation to threadgroup without bound checking */ + template + METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); + } + } + } + + /* Load from device memory into threadgroup memory - without bound checking */ + METAL_FUNC void load_unsafe() const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + *((threadgroup ReadVector*)(&dst[i * dst_ld])) = + *((const device ReadVector*)(&src[i * src_ld])); + } + } + + /* Load from device memory into threadgroup memory - with bound checking */ + METAL_FUNC void load_safe(short2 src_tile_dim) const { + src_tile_dim = src_tile_dim - short2(bj, bi); + + // Skip loading if thread has no valid reads + if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + return; + } + + // Use fast thread memory for bound checks + bool tmp_idx[vec_size]; + T tmp_val[vec_size]; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BROWS; i += TROWS) { + // Make sure tmp_idx only contains valid indices + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); + } + + // Read valid indices into tmp_val + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; + } + + // Zero out unneeded values + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + tmp_val[j] = tmp_idx[j] ? 
tmp_val[j] : T(0); + } + + // Copy values to threadgroup memory + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = tmp_val[j]; + } + } + } + + /* Iteration helper */ + METAL_FUNC void next() { + src += tile_stride; + } +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/mma.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/mma.h new file mode 100644 index 0000000..74151a9 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/mma.h @@ -0,0 +1,1146 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct BaseMMAFrag { + static_assert( + kFragRows_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); + static_assert( + kFragCols_ == 8, + "Only 8 x 8 fragment matrices are currently supported"); +}; + +template +struct BaseMMAFrag { + STEEL_CONST int kFragRows = 8; + STEEL_CONST int kFragCols = 8; + + STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST int kElemRows = 1; + STEEL_CONST int kElemCols = 2; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + typedef metal::simdgroup_matrix mat_type; + typedef metal::vec frag_type; + + METAL_FUNC static constexpr short2 get_coord(ushort simd_lane_id + [[thread_index_in_simdgroup]]) { + const short qid = simd_lane_id / 4; + const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); + const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; + 
return short2{fn, fm}; + } + + template + METAL_FUNC static constexpr void + load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void load_safe( + thread frag_type& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template + METAL_FUNC static constexpr void + store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_safe( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < lim_x && (off_y + j) < lim_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + 
static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX, + typename OffY> + METAL_FUNC static constexpr void store_slice( + const thread frag_type& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if ((off_x + i) < stop_x && (off_x + i) >= start_x && + (off_y + j) < stop_y && (off_y + j) >= start_y) { + dst[(off_x + i) * str_x + (off_y + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + METAL_FUNC static constexpr void mma( + thread frag_type& D, + thread frag_type& A, + thread frag_type& B, + thread frag_type& C) { + mat_type D_mat; + mat_type A_mat; + mat_type B_mat; + mat_type C_mat; + + reinterpret_cast(A_mat.thread_elements()) = A; + reinterpret_cast(B_mat.thread_elements()) = B; + reinterpret_cast(C_mat.thread_elements()) = C; + + mma(D_mat, A_mat, B_mat, C_mat); + + D = reinterpret_cast(D_mat.thread_elements()); + } + + METAL_FUNC static constexpr void mma( + thread mat_type& D, + thread mat_type& A, + thread mat_type& B, + thread mat_type& C) { + simdgroup_multiply_accumulate(D, A, B, C); + } +}; + +template < + typename T, + int kTileRows_, + int kTileCols_, + class MMAFrag_ = BaseMMAFrag> +struct MMATile { + using MMAFrag_t = MMAFrag_; + using elem_type = T; + STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; + STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; + STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; + + STEEL_CONST int kTileRows = kTileRows_; + STEEL_CONST int kTileCols = kTileCols_; + + STEEL_CONST int kRows = kTileRows * kFragRows; + STEEL_CONST int kCols = kTileCols * 
kFragCols; + + STEEL_CONST int kNumFrags = kTileRows * kTileCols; + STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; + + typedef typename MMAFrag_t::mat_type mat_type; + typedef typename MMAFrag_t::frag_type frag_type; + + frag_type val_frags[kNumFrags] = {frag_type(0)}; + + METAL_FUNC MMATile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kTileCols + j]; + } + + METAL_FUNC mat_type mat_at(const short i, const short j) { + mat_type val_mat; + STEEL_PRAGMA_UNROLL + for (short ii = 0; ii < kElemsPerFrag; ++ii) { + val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; + } + return val_mat; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &( + src[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &( + dst[(i * kFragRows) * w_x * str_x + + (j * kFragCols) * w_y * str_y]), + Int{}, + Int{}); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; 
j < kTileCols; ++j) { + MMAFrag_t::load( + frag_at(i, j), + &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + MMAFrag_t::store( + frag_at(i, j), + &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), + ld, + Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::load_safe( + frag_at(i, j), + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_safe( + frag_at(i, j), + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + MMAFrag_t::store_slice( + frag_at(i, j), + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + (i * kFragRows) * w_x, + (j * kFragCols) * w_y); + } + } + } +}; + +template +METAL_FUNC void tile_matmad( + thread MMATile& D, + thread MMATile& A, + thread MMATile& B, + thread MMATile& C) { + STEEL_PRAGMA_UNROLL + for (short m = 0; m < M; ++m) { + STEEL_PRAGMA_UNROLL + for (short n = 0; n < N; ++n) { + short n_serp = (m % 2) ? 
(N - 1 - n) : n; + STEEL_PRAGMA_UNROLL + for (short k = 0; k < K; ++k) { + MMATile::MMAFrag_t::mma( + D.frag_at(m, n_serp), + A.frag_at(m, k), + B.frag_at(k, n_serp), + C.frag_at(m, n_serp)); + } + } + } +} + +template +struct TransformNone { + static METAL_FUNC complex64_t apply(complex64_t x) { + return x; + } + static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) { + return x; + } +}; + +template < + typename T, + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType = float, + typename Epilogue = TransformNone> +struct BlockMMA { + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + STEEL_CONST short TM = BM / (kFragSize * WM); + // Warp tile size along N + STEEL_CONST short TN = BN / (kFragSize * WN); + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? 
ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // Simdgroup matrices + MMATile Atile; + MMATile Btile; + MMATile Ctile; + + // Offsets within threadgroup + short sm; + short sn; + + short As_offset; + short Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N + + sm += tm; + sn += tn; + } + + /* (BM, BK) X (BK, BN) multiply accumulate function */ + METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + Atile.template load(As); + + simdgroup_barrier(mem_flags::mem_none); + + Btile.template load(Bs); + + simdgroup_barrier(mem_flags::mem_none); + + tile_matmad(Ctile, Atile, Btile, Ctile); + + // Progress to next simdgroup tile + As += tile_stride_a; + Bs += tile_stride_b; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + Ctile.template store(D, 
ldd); + } + + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= short2(sn, sm); + + // TODO: Check the start as well + if (stop.y <= 0 || stop.x <= 0) { + return; + } + + Ctile.template store_slice(D, ldd, start, stop); + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); + } + + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + Ctile.template store_safe(D, ldd, dst_tile_dims); + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { + Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { + accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * 
fdc]); + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Read C + U c_elems[kelems] = {0}; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + c_elems[k] = C[offset_c + k * fdc]; + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + accum[k] = epilogue_op.apply(accum[k], c_elems[k]); + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[offset_d + k] = 
epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in C + thread const auto& accum = Ctile.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int offset_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[offset_d + k] = + epilogue_op.apply(accum[k], C[offset_c + k * fdc]); + } + } + } + } + } + } +}; + +template < + typename U, + int BM, + int BN, + int BK, + int WM, + int WN, + bool transpose_a, + bool transpose_b, + short lda_tgp, + short ldb_tgp, + typename AccumType, + typename Epilogue> +struct BlockMMA< + complex64_t, + U, + BM, + BN, + BK, + WM, + WN, + transpose_a, + transpose_b, + lda_tgp, + ldb_tgp, + AccumType, + Epilogue> { + static_assert( + metal::is_same_v, + "BlockMMA expects float accumulators"); + static_assert( + metal::is_same_v, + "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections"); + // MMAFrag size + STEEL_CONST short kFragSize = 8; + using MMAFrag_acc_t = BaseMMAFrag; + + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TM_stride = kFragSize * WM; + // Warp tile simdgroup matrix strides along M + STEEL_CONST short TN_stride = kFragSize * WN; + + // Warp tile size along M + 
STEEL_CONST short TM = BM / (kFragSize * WM); + // Warp tile size along N + STEEL_CONST short TN = BN / (kFragSize * WN); + + // Threadgroup A strides + STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M + STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K + + // Threadgroup B strides + STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K + STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N + + // Threadgroup strides along K + STEEL_CONST short tile_stride_a = kFragSize * A_str_k; + STEEL_CONST short tile_stride_b = kFragSize * B_str_k; + + // When indexing complex as float[2] + STEEL_CONST short A_str_m_f = A_str_m * 2; + STEEL_CONST short A_str_k_f = A_str_k * 2; + STEEL_CONST short B_str_k_f = B_str_k * 2; + STEEL_CONST short B_str_n_f = B_str_n * 2; + STEEL_CONST short tile_stride_a_f = tile_stride_a * 2; + STEEL_CONST short tile_stride_b_f = tile_stride_b * 2; + + // Accumulators (real/imag) + MMATile Ctile_r; + MMATile Ctile_i; + + // Offsets within threadgroup + short sm, sn; + short As_offset, Bs_offset; + + /* Constructor */ + METAL_FUNC BlockMMA( + ushort simd_group_id [[simdgroup_index_in_threadgroup]], + ushort simd_lane_id [[thread_index_in_simdgroup]]) { + // Determine thread position in simdgroup matrix + short tm = kFragSize * (simd_group_id / WN); + short tn = kFragSize * (simd_group_id % WN); + + short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); + sm = simd_coord.y; + sn = simd_coord.x; + + // Determine thread and simdgroup offset + As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K) + Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N) + + sm += tm; + sn += tn; + } + + /* Karatsuba MMA: 3 real MMAs per K-chunk */ + METAL_FUNC void mma( + const threadgroup complex64_t* As, + const threadgroup complex64_t* Bs) { + // Adjust for simdgroup and thread location + As += As_offset; + Bs += Bs_offset; + threadgroup const float* As_f = + reinterpret_cast(As); + threadgroup const float* Bs_f = + 
reinterpret_cast(Bs); + + // Iterate over BK in blocks of kFragSize + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < BK; kk += kFragSize) { + simdgroup_barrier(mem_flags::mem_none); + + MMATile Ar, Ai; + Ar.template load(As_f + 0); + Ai.template load(As_f + 1); + + simdgroup_barrier(mem_flags::mem_none); + + MMATile Br, Bi; + Br.template load(Bs_f + 0); + Bi.template load(Bs_f + 1); + + simdgroup_barrier(mem_flags::mem_none); + + // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi) + MMATile P, Q, R; + + tile_matmad(P, Ar, Br, P); + tile_matmad(Q, Ai, Bi, Q); + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i) + Ar.elems()[i] += Ai.elems()[i]; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i) + Br.elems()[i] += Bi.elems()[i]; + + tile_matmad(R, Ar, Br, R); + + // C_r += P - Q ; C_i -= Q + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) { + const auto p = P.elems()[i]; + const auto q = Q.elems()[i]; + const auto r = R.elems()[i]; + Ctile_r.elems()[i] += (p - q); + Ctile_i.elems()[i] += (r - p - q); + } + + // Progress to next simdgroup tile + As_f += tile_stride_a_f; + Bs_f += tile_stride_b_f; + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result(device U* D, const int ldd) { + // Adjust for simdgroup and thread location + D += sm * ldd + sn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off = (i * TM_stride) * ldd + (j * TN_stride); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + + METAL_FUNC void + store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { + D += sm * ldd + sn; + start -= short2(sn, sm); + stop -= 
short2(sn, sm); + + if (stop.y <= 0 || stop.x <= 0) + return; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + const int row = i * TM_stride; + if (row >= start.y && row < stop.y) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + const int off = row * ldd + (j * TN_stride); + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) { + const int col = j * TN_stride + k; + if (col >= start.x && col < stop.x) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + } + } + + METAL_FUNC void + store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { + D += sm * ldd + sn; + dst_tile_dims -= short2(sn, sm); + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + int off = (i * TM_stride) * ldd + (j * TN_stride); + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); + } + } + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) { + complex64_t out = epilogue_op.apply( + complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i])); + Ctile_r.elems()[i] = out.real; + Ctile_i.elems()[i] = out.imag; + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue( + const device U* C, + const int ldc, + const int fdc, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + + // Loop 
over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread auto& r = Ctile_r.frag_at(i, j); + thread auto& im = Ctile_i.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { + complex64_t out = epilogue_op.apply( + complex64_t(r[k], im[k]), C[offset_c + k * fdc]); + r[k] = out.real; + im[k] = out.imag; + } + } + } + } + + /* Apply epilogue */ + template + METAL_FUNC void apply_epilogue_safe( + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const BinaryEpilogue& epilogue_op) { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread auto& r = Ctile_r.frag_at(i, j); + thread auto& im = Ctile_i.frag_at(i, j); + int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + complex64_t tmp[kelems]; + + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x && + (i * TM_stride) < dst_tile_dims.y) { + tmp[k] = C[offset_c + k * fdc]; + } else { + tmp[k] = complex64_t(0.0f, 0.0f); + } + } + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]); + r[k] = out.real; + im[k] = out.imag; + } + } + } + } + + /* Store results from simdgroup_matrix results into device memory */ + METAL_FUNC void store_result( + device U* D, + const int ldd, + const 
device U* C, + const int ldc, + const int fdc, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + + // Loop over all simdgroup tiles + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int off_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + D[off_d + k] = + epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]); + } + } + } + } + + METAL_FUNC void store_result_safe( + device U* D, + const int ldd, + const device U* C, + const int ldc, + const int fdc, + short2 dst_tile_dims, + thread const Epilogue& epilogue_op) const { + // Adjust for simdgroup and thread location + C += (sm)*ldc + (sn)*fdc; + D += (sm)*ldd + sn; + dst_tile_dims -= short2(sn, sm); + + if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) + return; + + constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; + + STEEL_PRAGMA_UNROLL + for (int i = 0; i < TM; i++) { + if (i * TM_stride < dst_tile_dims.y) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < TN; j++) { + // Get accumulated result and associated offset in Cr, Ci + thread const auto& r = Ctile_r.frag_at(i, j); + thread const auto& im = Ctile_i.frag_at(i, j); + int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; + int off_d = (i * TM_stride) * ldd + (j * TN_stride); + + // Apply epilogue + STEEL_PRAGMA_UNROLL + for (short k = 0; k < kelems; k++) { + if ((j * TN_stride + k) < dst_tile_dims.x) { + D[off_d + k] = epilogue_op.apply( + complex64_t(r[k], im[k]), C[off_c + k * fdc]); + } + } + } + } + } + } +}; + +} // namespace 
steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/nax.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/nax.h new file mode 100644 index 0000000..5839176 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/nax.h @@ -0,0 +1,1084 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +#include "mlx/backend/metal/kernels/steel/defines.h" +#include "mlx/backend/metal/kernels/steel/gemm/transforms.h" +#include "mlx/backend/metal/kernels/steel/utils/integral_constant.h" + +#include + +using namespace metal; + +/////////////////////////////////////////////////////////////////////////////// +// MMA helper +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// NAX Steel with new tiles +/////////////////////////////////////////////////////////////////////////////// + +struct BaseNAXFrag { + STEEL_CONST short kFragRows = 16; + STEEL_CONST short kFragCols = 16; + + STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; + + STEEL_CONST short kElemRows = 2; + STEEL_CONST short kElemCols = 4; + + STEEL_CONST short kElemRowsJump = 8; + + static_assert( + kElemRows * kElemCols == kElemsPerFrag, + "MMAFrag shape is not consistent with MMAFrag size"); + + template + using dtype_frag_t = typename metal::vec; + + METAL_FUNC static short2 get_coord() { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; + return short2{fn, fm}; + } + + METAL_FUNC static short2 get_coord(short idx) { + const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); + const short qid = simd_lane_id >> 2; + const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + 
(idx >> 2) * 8; + const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; + return short2{fn, fm}; + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void load_rows( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } + } + + } else { + dst = dtype_frag_t(0); + } + } + } + + template < + typename T, + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = 
Int<0>> + METAL_FUNC static constexpr void load_safe( + thread dtype_frag_t& dst, + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[i * kElemCols + j] = + static_cast(src[r * str_x + (c + j) * str_y]); + } else { + dst[i * kElemCols + j] = T(0); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_rows( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = 
off_y + sc.x; + + if (r < lim_x) { + if constexpr (metal::is_same_v>) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_safe( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + const auto r = off_x + i * kElemRowsJump + sc.y; + const auto c = off_y + sc.x; + + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + if (r < lim_x && (c + j) < lim_y) { + dst[r * str_x + (c + j) * str_y] = + static_cast(src[i * kElemCols + j]); + } + } + } + } + + template < + typename T, + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC static constexpr void store_slice( + const thread dtype_frag_t& src, + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) { + using U = pointer_element_t; + + const short2 sc = get_coord(); + + const_for_loop<0, kElemRows, 1>([&](auto idx_row) { + const auto r = off_x + idx_row * Int{}; + if (r >= stop_x - sc.y || r < start_x - sc.y) { + return; + } + + const_for_loop<0, kElemCols, 1>([&](auto idx_col) { + const auto c = off_y + idx_col; + if (c >= stop_y - sc.x || c < start_y - sc.x) { + return; + } + + 
const auto src_idx = idx_row * Int{} + idx_col; + dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = + static_cast(src[src_idx]); + }); + }); + } + + template + METAL_FUNC static constexpr void row_reduce( + thread const dtype_frag_t& inp_vals, + thread T* reduced_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + T thr_reduce = Op::apply( + Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), + Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); + + T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); + qgr_reduce = Op::apply(thr_reduce, qgr_reduce); + + T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); + sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); + + reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); + } + } + + template + METAL_FUNC static constexpr void row_bin_op( + thread dtype_frag_t& inp_vals, + thread T* row_vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kElemRows; i++) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kElemCols; j++) { + inp_vals[i * kElemCols + j] = + Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); + } + } + } +}; + +template < + typename T, + short kRows_, + short kCols_, + typename NAXFrag_t = BaseNAXFrag> +struct NAXSubTile { + STEEL_CONST short kRows = kRows_; + STEEL_CONST short kCols = kCols_; + + STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; + STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; + STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; + + STEEL_CONST short kSubTileRows = kRows / kFragRows; + STEEL_CONST short kSubTileCols = kCols / kFragCols; + + STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; + STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; + + STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; + STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; + + STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; + STEEL_CONST short kFragThrCols = 
NAXFrag_t::kElemCols; + STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; + + using frag_type = typename NAXFrag_t::template dtype_frag_t; + + frag_type val_frags[kNumFrags]; + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kNumFrags; ++i) { + val_frags[i] = frag_type(0); + } + } + + METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC constexpr const thread frag_type& frag_at( + const short i, + const short j) const { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr thread frag_type& frag_at() { + return val_frags[i * kSubTileCols + j]; + } + + template + METAL_FUNC constexpr const thread frag_type& frag_at() const { + return val_frags[i * kSubTileCols + j]; + } + + METAL_FUNC thread T* elems() { + return reinterpret_cast(val_frags); + } + + METAL_FUNC const thread T* elems() const { + return reinterpret_cast(val_frags); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_reduce( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::template row_bin_op( + frag_at(i, j), &vals[i * kFragThrRows]); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load( + SrcPtrType src, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load( + 
frag_at(i, j), + src, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store( + DstPtrType dst, + StrX str_x, + StrY str_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store( + frag_at(i, j), + dst, + str_x, + str_y, + off_x + i * kFragRows, + off_y + j * kFragCols); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_rows( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_rows( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename SrcPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void load_safe( + SrcPtrType src, + StrX str_x, + StrY str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::load_safe( + frag_at(i, j), + src, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename LimY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_safe( + DstPtrType dst, + StrX str_x, + StrY 
str_y, + LimX lim_x, + LimY lim_y, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + lim_y, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename LimX, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_rows( + DstPtrType dst, + StrX str_x, + StrY str_y, + LimX lim_x, + OffX off_x = {}, + OffY off_y = {}) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kSubTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kSubTileCols; ++j) { + NAXFrag_t::store_safe( + frag_at(i, j), + dst, + str_x, + str_y, + lim_x, + off_x + (i * kFragRows), + off_y + (j * kFragCols)); + } + } + } + + template < + typename DstPtrType, + typename StrX, + typename StrY, + typename StartX, + typename StopX, + typename StartY, + typename StopY, + typename OffX = Int<0>, + typename OffY = Int<0>> + METAL_FUNC constexpr void store_slice( + DstPtrType dst, + StrX str_x, + StrY str_y, + StartX start_x, + StopX stop_x, + StartY start_y, + StopY stop_y, + OffX off_x = Int<0>{}, + OffY off_y = Int<0>{}) const { + const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { + NAXFrag_t::store_slice( + frag_at(), + dst, + str_x, + str_y, + start_x, + stop_x, + start_y, + stop_y, + off_x + idx_row * Int{}, + off_y + idx_col * Int{}); + }); + }); + } +}; + +template < + short RC, + short CC, + short RA, + short CA, + short RB, + short CB, + typename CType, + typename AType, + typename BType, + bool transpose_a, + bool transpose_b, + typename NAXFrag_t = BaseNAXFrag> +METAL_FUNC void subtile_matmad_nax( + thread NAXSubTile& C, + thread NAXSubTile& A, + metal::bool_constant, + thread NAXSubTile& B, + 
metal::bool_constant) { + // Static checks + constexpr short FMa = transpose_a ? CA : RA; + constexpr short FMc = RC; + static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); + + constexpr short FNb = transpose_b ? RB : CB; + constexpr short FNc = CC; + static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); + + constexpr short FKa = transpose_a ? RA : CA; + constexpr short FKb = transpose_b ? CB : RB; + static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); + + constexpr short FM = FMc; + constexpr short FN = FNc; + constexpr short FK = FKa; + + constexpr int TM = FM / 16; + constexpr int TN = FN / 16; + constexpr int TK = FK / 16; + + // Create Matmul descriptor + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( + FM, + FN, + FK, + transpose_a, + transpose_b, + true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); + + // Create matmul op + mpp::tensor_ops::matmul2d gemm_op; + + // Create matmul operands in registers + auto ct_a = + gemm_op.template get_left_input_cooperative_tensor(); + auto ct_b = + gemm_op + .template get_right_input_cooperative_tensor(); + + // Create matmul output in register + auto ct_c = gemm_op.template get_destination_cooperative_tensor< + decltype(ct_a), + decltype(ct_b), + CType>(); + + // Load A in to left operand registers + STEEL_PRAGMA_UNROLL + for (short mm = 0; mm < TM; mm++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_a ? kk : mm; + const short fj = transpose_a ? mm : kk; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; + } + } + } + + // Load B into right operand registers + STEEL_PRAGMA_UNROLL + for (short nn = 0; nn < TN; nn++) { + STEEL_PRAGMA_UNROLL + for (short kk = 0; kk < TK; kk++) { + const short fi = transpose_b ? nn : kk; + const short fj = transpose_b ? 
kk : nn; + + STEEL_PRAGMA_UNROLL + for (short i = 0; i < 8; i++) { + ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; + } + } + } + + // Load C into output registers (op handles accumulation) + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + ct_c[i] = C.elems()[i]; + } + + // Do matmul + gemm_op.run(ct_a, ct_b, ct_c); + + // Copy out results + STEEL_PRAGMA_UNROLL + for (short i = 0; i < ct_c.get_capacity(); i++) { + C.elems()[i] = ct_c[i]; + } +} + +template +struct NAXTile { + using NAXSubTile_t = NAXSubTile_; + using elem_type = T; + STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; + STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; + STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; + + STEEL_CONST short kTileRows = kTileRows_; + STEEL_CONST short kTileCols = kTileCols_; + + STEEL_CONST short kRows = kTileRows * kSubTileRows; + STEEL_CONST short kCols = kTileCols * kSubTileCols; + + STEEL_CONST short kSubTiles = kTileRows * kTileCols; + STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; + + STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; + + STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; + STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; + + NAXSubTile_t val_subtiles[kSubTiles]; + + METAL_FUNC NAXTile() thread {} + + METAL_FUNC constexpr void clear() { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kSubTiles; ++i) { + val_subtiles[i].clear(); + } + } + + METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( + const short i, + const short j) { + return val_subtiles[i * kTileCols + j]; + } + + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( + const short i, + const short j) const { + return val_subtiles[i * kTileCols + j]; + } + + template + METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { + return val_subtiles[i * 
kTileCols + j]; + } + + METAL_FUNC thread elem_type* elems() { + return reinterpret_cast(val_subtiles[0].elems()); + } + + METAL_FUNC const thread elem_type* elems() const { + return reinterpret_cast(val_subtiles[0].elems()); + } + + template + METAL_FUNC void row_reduce(thread metal::vec& vals) const { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_reduce(sub_rows[i]); + } + } + } + + template + METAL_FUNC void row_bin_op(thread metal::vec& vals) { + auto sub_rows = (thread metal::vec*)(&vals); + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).template row_bin_op(sub_rows[i]); + } + } + } + + template + METAL_FUNC void load(const threadgroup U* src) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + src, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store(threadgroup U* dst) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + dst, + Int{}, + Int{}, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void load(const device U* src, const int ld) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load( + &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void store(device U* dst, const int ld) const { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store( + &dst[(i * 
kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); + } + } + } + + template + METAL_FUNC void + load_rows(const device U* src, const int ld, const short n_rows) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_rows( + &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + load_safe(const device U* src, const int ld, const short2 src_tile_dims) { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).load_safe( + src, + ld, + Int<1>{}, + src_tile_dims.y, + src_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) + const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_rows( + &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], + ld, + Int<1>{}, + n_rows - i * kSubTileRows); + } + } + } + + template + METAL_FUNC void + store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { + STEEL_PRAGMA_UNROLL + for (int i = 0; i < kTileRows; ++i) { + STEEL_PRAGMA_UNROLL + for (int j = 0; j < kTileCols; ++j) { + subtile_at(i, j).store_safe( + dst, + ld, + Int<1>{}, + dst_tile_dims.y, + dst_tile_dims.x, + i * kSubTileRows, + j * kSubTileCols); + } + } + } + + template + METAL_FUNC void store_slice( + device U* dst, + const int ld, + const short2 start, + const short2 stop) const { + const_for_loop<0, kTileRows, 1>([&](auto idx_row) { + const_for_loop<0, kTileCols, 1>([&](auto idx_col) { + subtile_at().store_slice( + dst, + ld, + Int<1>{}, + start.y, + stop.y, + start.x, + stop.x, + idx_row * Int{}, + idx_col * Int{}); + }); + }); + } +}; + +template < + class CTile, + class ATile, + 
class BTile, + bool transpose_a, + bool transpose_b> +METAL_FUNC void tile_matmad_nax( + thread CTile& C, + thread ATile& A, + metal::bool_constant, + thread BTile& B, + metal::bool_constant) { + // Static checks + constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; + constexpr short TMc = CTile::kTileRows; + static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); + + constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; + constexpr short FMc = CTile::kSubTileRows; + static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); + + constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; + constexpr short TNc = CTile::kTileCols; + static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); + + constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; + constexpr short FNc = CTile::kSubTileCols; + static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); + + constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; + constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; + static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); + + constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; + constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; + static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); + + constexpr short TM = TMc; + constexpr short TN = TNc; + constexpr short TK = TKa; + + // Do matmul here + STEEL_PRAGMA_UNROLL + for (short i = 0; i < TM; ++i) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < TN; ++j) { + STEEL_PRAGMA_UNROLL + for (short k = 0; k < TK; ++k) { + const short ra = transpose_a ? k : i; + const short ca = transpose_a ? i : k; + const short rb = transpose_b ? j : k; + const short cb = transpose_b ? 
k : j; + + subtile_matmad_nax( + C.subtile_at(i, j), + A.subtile_at(ra, ca), + metal::bool_constant{}, + B.subtile_at(rb, cb), + metal::bool_constant{}); + } + } + } +} + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/params.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/params.h new file mode 100644 index 0000000..3cb7bdc --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/params.h @@ -0,0 +1,64 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +/////////////////////////////////////////////////////////////////////////////// +// GEMM param classes +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +struct GEMMParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldd; + + const int tiles_n; + const int tiles_m; + + const int64_t batch_stride_a; + const int64_t batch_stride_b; + const int64_t batch_stride_d; + + const int swizzle_log; + const int gemm_k_iterations_aligned; + + const int batch_ndim; +}; + +struct GEMMSpiltKParams { + const int M; + const int N; + const int K; + + const int lda; + const int ldb; + const int ldc; + + const int tiles_n; + const int tiles_m; + + const int split_k_partitions; + const int split_k_partition_stride; + const int split_k_partition_size; + + const int gemm_k_iterations_aligned; +}; + +struct GEMMAddMMParams { + const int ldc; + const int fdc; + + const int64_t batch_stride_c; + + const float alpha; + const float beta; +}; + +} // namespace steel +} // namespace mlx diff --git a/dist/include/mlx/backend/metal/kernels/steel/gemm/transforms.h b/dist/include/mlx/backend/metal/kernels/steel/gemm/transforms.h new file mode 100644 index 0000000..0282a12 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/gemm/transforms.h @@ -0,0 +1,72 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/kernels/steel/utils.h" + +/////////////////////////////////////////////////////////////////////////////// +// Transforms and Epilogues +/////////////////////////////////////////////////////////////////////////////// + +namespace mlx { +namespace steel { + +template +struct TransformNone { + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT) { + return static_cast(x); + } +}; + +template +struct TransformAdd { + TransformAdd(const float, const float) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + static METAL_FUNC OutT apply(InT x, OutT c) { + return static_cast(x) + c; + } +}; + +template +struct TransformAxpby { + const float alpha; + const float beta; + + TransformAxpby(const float alpha_, const float beta_) + : alpha(alpha_), beta(beta_) {} + + static METAL_FUNC OutT apply(InT x) { + return static_cast(x); + } + + METAL_FUNC OutT apply(InT x, OutT c) const { + return static_cast( + x * static_cast(alpha) + (static_cast(beta) * c)); + } +}; + +template +struct AccumHelper { + typedef float accum_type; +}; + +struct BlockSwizzle { + static METAL_FUNC int2 + swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { + const int tid_x = (tid.x) >> swizzle_log; + const int tid_y = + ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); + return int2(tid_x, tid_y); + } +}; + +} // namespace steel +} // namespace mlx \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/utils.h b/dist/include/mlx/backend/metal/kernels/steel/utils.h new file mode 100644 index 0000000..55720a2 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/utils.h @@ -0,0 +1,42 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include + +METAL_FUNC ulong2 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + } + return ulong2(loc_a, loc_b); +} + +METAL_FUNC ulong3 elem_to_loc_broadcast( + uint elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + int ndim) { + ulong loc_a{0}; + ulong loc_b{0}; + ulong loc_c{0}; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + int pos_in_dim = (elem % shape[i]); + elem /= shape[i]; + loc_a += pos_in_dim * a_strides[i]; + loc_b += pos_in_dim * b_strides[i]; + loc_c += pos_in_dim * c_strides[i]; + } + return ulong3(loc_a, loc_b, loc_c); +} diff --git a/dist/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h b/dist/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h new file mode 100644 index 0000000..526f561 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/utils/integral_constant.h @@ -0,0 +1,134 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include +#include "mlx/backend/metal/kernels/steel/utils/type_traits.h" + +#pragma METAL internals : enable + +namespace mlx { +namespace steel { + +/////////////////////////////////////////////////////////////////////////////// +// Integral constant with casting +/////////////////////////////////////////////////////////////////////////////// + +template +struct integral_constant { + static constexpr constant T value = v; + using value_type = T; + using type = integral_constant; + + METAL_FUNC constexpr operator value_type() const noexcept { + return value; + } + + // METAL_FUNC constexpr value_type operator()() const noexcept { + // return value; + // } +}; + +template +using bool_constant = integral_constant; +using true_type = bool_constant; +using false_type = bool_constant; + +template +struct is_integral : bool_constant::value> {}; + +template +struct is_integral> + : bool_constant::value> {}; + +template +constexpr constant bool is_integral_v = is_integral::value; + +template +using Int = integral_constant; + +/////////////////////////////////////////////////////////////////////////////// +// Binary Operators on Integral constants +/////////////////////////////////////////////////////////////////////////////// + +#define integral_const_binop(__op__, __operator__) \ + template \ + METAL_FUNC constexpr auto __operator__( \ + integral_constant, integral_constant) { \ + constexpr auto res = tv __op__ uv; \ + return integral_constant{}; \ + } + +integral_const_binop(+, operator+); +integral_const_binop(-, operator-); +integral_const_binop(*, operator*); +integral_const_binop(/, operator/); + +integral_const_binop(==, operator==); +integral_const_binop(!=, operator!=); +integral_const_binop(<, operator<); +integral_const_binop(>, operator>); +integral_const_binop(<=, operator<=); +integral_const_binop(>=, operator>=); + +integral_const_binop(&&, operator&&); +integral_const_binop(||, operator||); + +template >> +METAL_FUNC constexpr auto 
operator||(true_type, T) { + return true_type{}; +} +template >> +METAL_FUNC constexpr auto operator||(T, true_type) { + return true_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(false_type, T) { + return false_type{}; +} + +template >> +METAL_FUNC constexpr auto operator&&(T, false_type) { + return false_type{}; +} + +// Dispatch utilities +template +void dispatch_bool(bool v, F f) { + if (v) { + f(true_type{}); + } else { + f(false_type{}); + } +} + +template +constexpr void const_for_loop(F f) { + if constexpr (start < stop) { + constexpr auto idx = Int{}; + f(idx); + const_for_loop(f); + } +} + +#undef integral_const_binop + +/////////////////////////////////////////////////////////////////////////////// +// Reduction operators +/////////////////////////////////////////////////////////////////////////////// + +template +METAL_FUNC constexpr T sum(T x) { + return x; +} + +template +METAL_FUNC constexpr auto sum(T x, Us... us) { + return x + sum(us...); +} + +} // namespace steel +} // namespace mlx + +#pragma METAL internals : disable \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/steel/utils/type_traits.h b/dist/include/mlx/backend/metal/kernels/steel/utils/type_traits.h new file mode 100644 index 0000000..f004dc8 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/steel/utils/type_traits.h @@ -0,0 +1,55 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include + +#pragma METAL internals : enable + +namespace metal { + +template +struct is_empty : metal::bool_constant<__is_empty(T)> {}; + +#ifdef __cpp_variable_templates +template +constexpr constant bool is_empty_v = is_empty::value; +#endif + +template +struct make_void { + typedef void type; +}; + +template +using void_t = typename make_void::type; + +template +struct is_static : metal::bool_constant>::value> {}; + +template +struct pointer_element {}; + +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; +template +struct pointer_element { + using type = remove_cv_t; +}; + +template +using pointer_element_t = typename pointer_element>::type; + +} // namespace metal + +#pragma METAL internals : disable \ No newline at end of file diff --git a/dist/include/mlx/backend/metal/kernels/ternary.h b/dist/include/mlx/backend/metal/kernels/ternary.h new file mode 100644 index 0000000..705b73e --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/ternary.h @@ -0,0 +1,145 @@ +// Copyright © 2024 Apple Inc. + +template < + typename T, + typename Op, + bool BSCALAR, + bool CSCALAR, + int N = WorkPerThread::n> +[[kernel]] void ternary_v( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto bidx = BSCALAR ? 0 : index + i; + auto cidx = CSCALAR ? 0 : index + i; + d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); + } + } else { + for (int i = 0; i < N; ++i) { + auto bidx = BSCALAR ? 0 : index + i; + auto cidx = CSCALAR ? 
0 : index + i; + d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); + } + } +} + +template < + typename T, + typename Op, + bool BSCALAR, + bool CSCALAR, + int N = WorkPerThread::n> +[[kernel]] void ternary_v2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto bidx = BSCALAR ? 0 : offset + i; + auto cidx = CSCALAR ? 0 : offset + i; + d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); + } + } else { + for (int i = 0; i < N; ++i) { + auto bidx = BSCALAR ? 0 : offset + i; + auto cidx = CSCALAR ? 0 : offset + i; + d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); + } + } +} + +template +[[kernel]] void ternary_g_nd1( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t& a_strides, + constant const int64_t& b_strides, + constant const int64_t& c_strides, + uint index [[thread_position_in_grid]]) { + auto a_idx = elem_to_loc_1(index, a_strides); + auto b_idx = elem_to_loc_1(index, b_strides); + auto c_idx = elem_to_loc_1(index, c_strides); + d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd2( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t a_strides[2], + constant const int64_t b_strides[2], + constant const int64_t c_strides[2], + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_2(index, a_strides); + auto b_idx = elem_to_loc_2(index, b_strides); + auto c_idx = elem_to_loc_2(index, c_strides); + IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g_nd3( + device const 
bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int64_t a_strides[3], + constant const int64_t b_strides[3], + constant const int64_t c_strides[3], + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto a_idx = elem_to_loc_3(index, a_strides); + auto b_idx = elem_to_loc_3(index, b_strides); + auto c_idx = elem_to_loc_3(index, c_strides); + IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); + d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); +} + +template +[[kernel]] void ternary_g( + device const bool* a, + device const T* b, + device const T* c, + device T* d, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + constant const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc_3_nd( + {N * index.x, index.y, index.z}, + shape, + a_strides, + b_strides, + c_strides, + ndim); + auto xshape = shape[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + IdxT a_xstride = a_strides[ndim - 1]; + IdxT b_xstride = b_strides[ndim - 1]; + IdxT c_xstride = c_strides[ndim - 1]; + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); + idx.x += a_xstride; + idx.y += b_xstride; + idx.z += c_xstride; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/ternary_ops.h b/dist/include/mlx/backend/metal/kernels/ternary_ops.h new file mode 100644 index 0000000..e0235d9 --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/ternary_ops.h @@ -0,0 +1,10 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +struct Select { + template + T operator()(bool condition, T x, T y) { + return condition ? 
x : y; + } +}; diff --git a/dist/include/mlx/backend/metal/kernels/unary.h b/dist/include/mlx/backend/metal/kernels/unary.h new file mode 100644 index 0000000..db7be3d --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/unary.h @@ -0,0 +1,63 @@ +// Copyright © 2024 Apple Inc. + +template ::n> +[[kernel]] void unary_v( + device const T* in, + device U* out, + constant uint& size, + uint index [[thread_position_in_grid]]) { + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = static_cast(Op()(in[index + i])); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = static_cast(Op()(in[index + i])); + } + } +} + +template ::n> +[[kernel]] void unary_v2( + device const T* in, + device U* out, + constant int64_t& size, + uint2 index [[thread_position_in_grid]], + uint2 grid_dim [[threads_per_grid]]) { + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = static_cast(Op()(in[offset + i])); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = static_cast(Op()(in[offset + i])); + } + } +} + +template < + typename T, + typename U, + typename Op, + int N = 1, + typename IdxT = int64_t> +[[kernel]] void unary_g( + device const T* in, + device U* out, + constant const int* in_shape, + constant const int64_t* in_strides, + device const int& ndim, + uint3 index [[thread_position_in_grid]], + uint3 grid_dim [[threads_per_grid]]) { + auto idx = elem_to_loc( + {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); + auto xshape = in_shape[ndim - 1]; + IdxT xstride = in_strides[ndim - 1]; + IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); + for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { + out[out_idx++] = static_cast(Op()(in[idx])); + idx += xstride; + } +} diff --git a/dist/include/mlx/backend/metal/kernels/unary_ops.h 
b/dist/include/mlx/backend/metal/kernels/unary_ops.h new file mode 100644 index 0000000..327bb5a --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/unary_ops.h @@ -0,0 +1,454 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/metal/kernels/cexpf.h" +#include "mlx/backend/metal/kernels/erf.h" +#include "mlx/backend/metal/kernels/expm1f.h" +#include "mlx/backend/metal/kernels/fp8.h" + +namespace { +constant float inf = metal::numeric_limits::infinity(); +} + +struct Abs { + template + T operator()(T x) { + return metal::abs(x); + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; + complex64_t operator()(complex64_t x) { + return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; + }; +}; + +struct ArcCos { + template + T operator()(T x) { + return metal::precise::acos(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcCosh { + template + T operator()(T x) { + return metal::precise::acosh(x); + }; +}; + +struct ArcSin { + template + T operator()(T x) { + return metal::precise::asin(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcSinh { + template + T operator()(T x) { + return metal::precise::asinh(x); + }; +}; + +struct ArcTan { + template + T operator()(T x) { + return metal::precise::atan(x); + }; + + complex64_t operator()(complex64_t x); +}; + +struct ArcTanh { + template + T operator()(T x) { + return metal::precise::atanh(x); + }; +}; + +struct BitwiseInvert { + template + T operator()(T x) { + return ~x; + }; +}; + +struct Ceil { + template + T operator()(T x) { + return metal::ceil(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + 
int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct Cos { + template + T operator()(T x) { + return metal::precise::cos(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::cos(x.real) * metal::precise::cosh(x.imag), + -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Cosh { + template + T operator()(T x) { + return metal::precise::cosh(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::cosh(x.real) * metal::precise::cos(x.imag), + metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Conjugate { + complex64_t operator()(complex64_t x) { + return complex64_t{x.real, -x.imag}; + } +}; + +struct Erf { + template + T operator()(T x) { + return static_cast(erf(static_cast(x))); + }; +}; + +struct ErfInv { + template + T operator()(T x) { + return static_cast(erfinv(static_cast(x))); + }; +}; + +struct Exp { + template + T operator()(T x) { + return metal::precise::exp(x); + }; + complex64_t operator()(complex64_t x) { + return cexpf(x); + } +}; + +struct Expm1 { + template + T operator()(T x) { + return static_cast(expm1f(static_cast(x))); + }; +}; + +struct Floor { + template + T operator()(T x) { + return metal::floor(x); + }; + int8_t operator()(int8_t x) { + return x; + }; + int16_t operator()(int16_t x) { + return x; + }; + int32_t operator()(int32_t x) { + return x; + }; + int64_t operator()(int64_t x) { + return x; + }; + uint8_t operator()(uint8_t x) { + return x; + }; + uint16_t operator()(uint16_t x) { + return x; + }; + uint32_t operator()(uint32_t x) { + return x; + }; + uint64_t operator()(uint64_t x) { + return x; + }; + bool operator()(bool x) { + return x; + }; +}; + +struct 
Imag { + float operator()(complex64_t x) { + return x.imag; + }; +}; + +struct Log { + template + T operator()(T x) { + return metal::precise::log(x); + }; + + complex64_t operator()(complex64_t x) { + auto r = metal::precise::log(Abs{}(x).real); + auto i = metal::precise::atan2(x.imag, x.real); + return {r, i}; + }; +}; + +struct Log2 { + template + T operator()(T x) { + return metal::precise::log2(x); + }; + + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN2_F, y.imag / M_LN2_F}; + }; +}; + +struct Log10 { + template + T operator()(T x) { + return metal::precise::log10(x); + }; + + complex64_t operator()(complex64_t x) { + auto y = Log{}(x); + return {y.real / M_LN10_F, y.imag / M_LN10_F}; + }; +}; + +struct Log1p { + template + T operator()(T x) { + return log1p(x); + }; +}; + +struct LogicalNot { + template + T operator()(T x) { + return !x; + }; +}; + +struct Negative { + template + T operator()(T x) { + return -x; + }; +}; + +struct Real { + float operator()(complex64_t x) { + return x.real; + }; +}; + +struct Round { + template + T operator()(T x) { + return metal::rint(x); + }; + complex64_t operator()(complex64_t x) { + return {metal::rint(x.real), metal::rint(x.imag)}; + }; +}; + +struct Sigmoid { + template + T operator()(T x) { + auto y = 1 / (1 + metal::exp(metal::abs(x))); + return (x < 0) ? 
y : 1 - y; + } +}; + +struct Sign { + template + T operator()(T x) { + return (x > T(0)) - (x < T(0)); + }; + uint32_t operator()(uint32_t x) { + return x != 0; + }; + complex64_t operator()(complex64_t x) { + if (x == complex64_t(0)) { + return x; + } + return x / + (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); + }; +}; + +struct Sin { + template + T operator()(T x) { + return metal::precise::sin(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::sin(x.real) * metal::precise::cosh(x.imag), + metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; + }; +}; + +struct Sinh { + template + T operator()(T x) { + return metal::precise::sinh(x); + }; + + complex64_t operator()(complex64_t x) { + return { + metal::precise::sinh(x.real) * metal::precise::cos(x.imag), + metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; + }; +}; + +struct Square { + template + T operator()(T x) { + return x * x; + }; +}; + +struct Sqrt { + template + T operator()(T x) { + return metal::precise::sqrt(x); + }; + + complex64_t operator()(complex64_t x) { + if (x.real == 0.0 && x.imag == 0.0) { + return {0.0, 0.0}; + } + auto r = Abs{}(x).real; + auto a = metal::precise::sqrt((r + x.real) / 2.0); + auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); + auto b = metal::copysign(b_abs, x.imag); + return {a, b}; + } +}; + +struct Rsqrt { + template + T operator()(T x) { + return metal::precise::rsqrt(x); + }; + + complex64_t operator()(complex64_t x) { + return 1.0 / Sqrt{}(x); + } +}; + +struct Tan { + template + T operator()(T x) { + return metal::precise::tan(x); + }; + + complex64_t operator()(complex64_t x) { + float tan_a = metal::precise::tan(x.real); + float tanh_b = metal::precise::tanh(x.imag); + float t1 = tan_a * tanh_b; + float denom = 1. 
+ t1 * t1; + return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; + }; +}; + +struct Tanh { + template + T operator()(T x) { + return metal::precise::tanh(x); + }; + + complex64_t operator()(complex64_t x) { + float tanh_a = metal::precise::tanh(x.real); + float tan_b = metal::precise::tan(x.imag); + float t1 = tanh_a * tan_b; + float denom = 1. + t1 * t1; + return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; + }; +}; + +complex64_t ArcCos::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcSin::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); + return {y.imag, -y.real}; +}; + +complex64_t ArcTan::operator()(complex64_t x) { + auto i = complex64_t{0.0, 1.0}; + auto ix = i * x; + return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); +}; + +struct ToFP8 { + template + uint8_t operator()(T f) { + return fp8_e4m3(f).bits; + } +}; + +struct FromFP8 { + float operator()(uint8_t x) { + return float(*(thread fp8_e4m3*)(&x)); + } +}; diff --git a/dist/include/mlx/backend/metal/kernels/utils.h b/dist/include/mlx/backend/metal/kernels/utils.h new file mode 100644 index 0000000..acdbc6a --- /dev/null +++ b/dist/include/mlx/backend/metal/kernels/utils.h @@ -0,0 +1,444 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/backend/metal/kernels/bf16.h" +#include "mlx/backend/metal/kernels/bf16_math.h" +#include "mlx/backend/metal/kernels/complex.h" +#include "mlx/backend/metal/kernels/defines.h" + +typedef half float16_t; + +// Work per thread values for different types. 
The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + +/////////////////////////////////////////////////////////////////////////////// +// Type limits utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct Limits { + static const constant U max = metal::numeric_limits::max(); + static const constant U min = metal::numeric_limits::min(); + static const constant U finite_max = metal::numeric_limits::max(); + static const constant U finite_min = metal::numeric_limits::min(); +}; + +#define instantiate_default_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = metal::numeric_limits::max(); \ + static constexpr constant type min = metal::numeric_limits::min(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + metal::numeric_limits::min(); \ + }; + +instantiate_default_limit(uint8_t); +instantiate_default_limit(uint16_t); +instantiate_default_limit(uint32_t); +instantiate_default_limit(uint64_t); +instantiate_default_limit(int8_t); +instantiate_default_limit(int16_t); +instantiate_default_limit(int32_t); +instantiate_default_limit(int64_t); + +#define instantiate_float_limit(type) \ + template <> \ + struct Limits { \ + static constexpr constant type max = \ + metal::numeric_limits::infinity(); \ + static constexpr constant type min = \ + -metal::numeric_limits::infinity(); \ + static constexpr constant type finite_max = \ + metal::numeric_limits::max(); \ + static constexpr constant type finite_min = \ + -metal::numeric_limits::max(); \ + }; + +instantiate_float_limit(half); +instantiate_float_limit(float); +instantiate_float_limit(bfloat16_t); + +template <> +struct Limits { + static constexpr constant bool max = true; + 
static constexpr constant bool min = false; +}; + +template <> +struct Limits { + static constexpr constant complex64_t max = complex64_t( + metal::numeric_limits::infinity(), + metal::numeric_limits::infinity()); + static constexpr constant complex64_t min = complex64_t( + -metal::numeric_limits::infinity(), + -metal::numeric_limits::infinity()); +}; + +/////////////////////////////////////////////////////////////////////////////// +// Indexing utils +/////////////////////////////////////////////////////////////////////////////// + +#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with generic dims + +template +METAL_FUNC IdxT elem_to_loc( + IdxT elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = 0; + for (int i = ndim - 1; i >= 0 && elem > 0; --i) { + loc += (elem % shape[i]) * IdxT(strides[i]); + elem /= shape[i]; + } + return loc; +} + +// Non templated version to handle arbitrary dims +template +METAL_FUNC IdxT elem_to_loc( + uint3 elem, + constant const int* shape, + constant const int64_t* strides, + int ndim) { + IdxT loc = + elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); + for (int d = ndim - 3; d >= 0; --d) { + loc += (elem.z % shape[d]) * IdxT(strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Single Array with fixed N dims + +template +METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) { + return elem * IdxT(stride); +} + +template +METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) { + return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); +} + +template +METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) { + return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + + elem.z * IdxT(strides[0]); +} 
+ +/////////////////////////////////////////////////////////////////////////////// +// Multiple Arrays with generic dims + +template +METAL_FUNC vec elem_to_loc_2_nd( + uint3 elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + int ndim) { + vec loc = { + IdxT( + elem.x * IdxT(a_strides[ndim - 1]) + + IdxT(elem.y) * IdxT(a_strides[ndim - 2])), + IdxT( + elem.x * IdxT(b_strides[ndim - 1]) + + elem.y * IdxT(b_strides[ndim - 2]))}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +template +METAL_FUNC vec elem_to_loc_3_nd( + uint3 elem, + constant const int* shape, + constant const int64_t* a_strides, + constant const int64_t* b_strides, + constant const int64_t* c_strides, + int ndim) { + vec loc = { + IdxT(elem.x * IdxT(a_strides[ndim - 1])) + + IdxT(elem.y * IdxT(a_strides[ndim - 2])), + IdxT(elem.x * IdxT(b_strides[ndim - 1])) + + IdxT(elem.y * IdxT(b_strides[ndim - 2])), + IdxT(elem.x * IdxT(c_strides[ndim - 1])) + + IdxT(elem.y * IdxT(c_strides[ndim - 2]))}; + for (int d = ndim - 3; d >= 0; --d) { + uint l = elem.z % shape[d]; + loc.x += l * IdxT(a_strides[d]); + loc.y += l * IdxT(b_strides[d]); + loc.z += l * IdxT(c_strides[d]); + elem.z /= shape[d]; + } + return loc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Elem to loc in a loop utils +/////////////////////////////////////////////////////////////////////////////// + +template +struct LoopedElemToLoc { + int dim; + LoopedElemToLoc inner_looper; + OffsetT offset{0}; + int index{0}; + + LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} + + void next(const constant int* shape, const constant int64_t* strides) { + if (dim == 0) { + return; + } + index++; + offset += OffsetT(strides[dim - 1]); + if (index >= shape[dim - 1]) { + index = 0; + inner_looper.next(shape, 
strides); + offset = inner_looper.offset; + } + } + + void next(int n, const constant int* shape, const constant int64_t* strides) { + if (dim == 0) { + return; + } + index += n; + offset += n * OffsetT(strides[dim - 1]); + + if (index >= shape[dim - 1]) { + int extra = index - shape[dim - 1]; + if (extra >= shape[dim - 1]) { + inner_looper.next(1 + extra / shape[dim - 1], shape, strides); + extra = extra % shape[dim - 1]; + } else { + inner_looper.next(shape, strides); + } + index = 0; + offset = inner_looper.offset; + if (extra > 0) { + next(extra, shape, strides); + } + } + } + + OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, OffsetT, true> { + int dim; + OffsetT offset{0}; + uint index{0}; + + LoopedElemToLoc(int dim) : dim(dim) {} + + void next(const constant int* shape, const constant int64_t* strides) { + index++; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset += OffsetT(strides[0]); + } + } + + void next(int n, const constant int* shape, const constant int64_t* strides) { + index += n; + if (dim > 1) { + offset = elem_to_loc(index, shape, strides, dim); + } else { + offset = index * OffsetT(strides[0]); + } + } + + OffsetT location() { + return offset; + } +}; + +template +struct LoopedElemToLoc<1, OffsetT, false> { + OffsetT offset{0}; + + LoopedElemToLoc(int) {} + + void next(const constant int*, const constant int64_t* strides) { + offset += OffsetT(strides[0]); + } + + void next(int n, const constant int*, const constant int64_t* strides) { + offset += n * OffsetT(strides[0]); + } + + OffsetT location() { + return offset; + } +}; + +/////////////////////////////////////////////////////////////////////////////// +// Calculation utils +/////////////////////////////////////////////////////////////////////////////// + +/** Compute ceil((float)N/(float)M) */ +template +inline T ceildiv(T N, U M) { + return (N + M - 1) / M; +} + +// 
https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 +inline float log1p(float x) { + float xp1 = 1.0f + x; + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return x * (metal::log(xp1) / (xp1 - 1.0f)); +} + +inline bfloat16_t log1p(bfloat16_t x) { + float xp1 = 1.0f + static_cast(x); + if (xp1 == Limits::max) { + return Limits::max; + } + if (xp1 == 1.0f) { + return x; + } + + return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); +} + +inline complex64_t log1p(complex64_t in) { + float x = in.real; + float y = in.imag; + float zabs = metal::precise::sqrt(x * x + y * y); + float theta = metal::atan2(y, x + 1); + if (zabs < 0.5f) { + float r = x * (2 + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {0.5f * log1p(r), theta}; + } else { + auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); + return {metal::log(z0), theta}; + } +} + +/////////////////////////////////////////////////////////////////////////////// +// SIMD shuffle ops +/////////////////////////////////////////////////////////////////////////////// + +inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { + return as_type( + metal::simd_shuffle_down(as_type(data), delta)); +} + +inline bool simd_shuffle_down(bool data, uint16_t delta) { + return simd_shuffle_down(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); +} + +inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) { + return as_type(metal::simd_shuffle_up(as_type(data), delta)); +} + +inline bool 
simd_shuffle_up(bool data, uint16_t delta) { + return simd_shuffle_up(static_cast(data), delta); +} + +inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) { + return complex64_t( + simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta)); +} + +inline uint64_t +simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline int64_t +simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) { + return as_type(metal::simd_shuffle_and_fill_up( + as_type(data), as_type(filling), delta)); +} + +inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { + return simd_shuffle_and_fill_up( + static_cast(data), static_cast(filling), delta); +} + +inline complex64_t simd_shuffle_and_fill_up( + complex64_t data, + complex64_t filling, + uint16_t delta) { + return complex64_t( + simd_shuffle_and_fill_up(data.real, filling.real, delta), + simd_shuffle_and_fill_up(data.imag, filling.imag, delta)); +} + +inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline int64_t simd_shuffle(int64_t data, uint16_t lane) { + return as_type(metal::simd_shuffle(as_type(data), lane)); +} + +inline bool simd_shuffle(bool data, uint16_t lane) { + return simd_shuffle(static_cast(data), lane); +} + +inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { + return complex64_t( + simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); +} + +// std::conditional is not included with Metal +template +struct ConditionalType { + using type = U; +}; + +template +struct ConditionalType { + using type = T; +}; diff --git a/dist/include/mlx/backend/metal/matmul.h b/dist/include/mlx/backend/metal/matmul.h new file mode 100644 index 0000000..218664b --- /dev/null +++ b/dist/include/mlx/backend/metal/matmul.h @@ -0,0 +1,144 @@ +// 
Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/backend/metal/device.h" + +namespace mlx::core { + +template +void steel_matmul_regular_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out, + int64_t C_batch_stride = 0, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul_regular( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + int ldd, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape, + Strides batch_strides, + int64_t A_batch_stride, + int64_t B_batch_stride, + int64_t matrix_stride_out) { + return steel_matmul_regular_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* int ldd = */ ldd, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides batch_strides = */ batch_strides, + /* int64_t A_batch_stride = */ A_batch_stride, + /* int64_t B_batch_stride = */ B_batch_stride, + /* int64_t matrix_stride_out = */ matrix_stride_out); +} + +template +void steel_matmul_axpby( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + const array& c, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool 
transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}, + Strides C_batch_stride = {}, + float alpha = 1.0f, + float beta = 0.0f); + +inline void steel_matmul( + const Stream& s, + metal::Device& d, + const array& a, + const array& b, + array& out, + int M, + int N, + int K, + int batch_size_out, + int lda, + int ldb, + bool transpose_a, + bool transpose_b, + std::vector& copies, + Shape batch_shape = {}, + Strides A_batch_stride = {}, + Strides B_batch_stride = {}) { + return steel_matmul_axpby( + /* const Stream& s = */ s, + /* metal::Device& d = */ d, + /* const array& a = */ a, + /* const array& b = */ b, + /* const array& c = */ b, + /* array& out = */ out, + /* int M = */ M, + /* int N = */ N, + /* int K = */ K, + /* int batch_size_out = */ batch_size_out, + /* int lda = */ lda, + /* int ldb = */ ldb, + /* bool transpose_a = */ transpose_a, + /* bool transpose_b = */ transpose_b, + /* std::vector& copies = */ copies, + /* Shape batch_shape = */ batch_shape, + /* Strides A_batch_stride = */ A_batch_stride, + /* Strides B_batch_stride = */ B_batch_stride); +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/metal.h b/dist/include/mlx/backend/metal/metal.h new file mode 100644 index 0000000..af2995b --- /dev/null +++ b/dist/include/mlx/backend/metal/metal.h @@ -0,0 +1,22 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::metal { + +/* Check if the Metal backend is available. */ +bool is_available(); + +/** Capture a GPU trace, saving it to an absolute file `path` */ +void start_capture(std::string path = ""); +void stop_capture(); + +/** Get information about the GPU and system settings. 
*/ +const std::unordered_map>& +device_info(); + +} // namespace mlx::core::metal diff --git a/dist/include/mlx/backend/metal/reduce.h b/dist/include/mlx/backend/metal/reduce.h new file mode 100644 index 0000000..a997d7e --- /dev/null +++ b/dist/include/mlx/backend/metal/reduce.h @@ -0,0 +1,41 @@ +// Copyright @ 2023 - 2024 Apple Inc. + +#pragma once + +#include "mlx/backend/common/reduce.h" +#include "mlx/backend/metal/device.h" +#include "mlx/stream.h" + +namespace mlx::core { + +using metal::CommandEncoder; + +void all_reduce_dispatch( + const array& in, + array& out, + const std::string& op_name, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +void row_reduce_general_dispatch( + const array& in, + array& out, + const std::string& op_name, + const ReductionPlan& plan, + const std::vector& axes, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +void strided_reduce_general_dispatch( + const array& in, + array& out, + const std::string& op_name, + const ReductionPlan& plan, + const std::vector& axes, + CommandEncoder& compute_encoder, + metal::Device& d, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/resident.h b/dist/include/mlx/backend/metal/resident.h new file mode 100644 index 0000000..5db5582 --- /dev/null +++ b/dist/include/mlx/backend/metal/resident.h @@ -0,0 +1,32 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/backend/metal/device.h" + +namespace mlx::core::metal { + +class ResidencySet { + public: + ResidencySet(MTL::Device* d); + ~ResidencySet(); + + ResidencySet(const ResidencySet&) = delete; + ResidencySet& operator=(const ResidencySet&) = delete; + + const MTL::ResidencySet* mtl_residency_set() { + return wired_set_; + } + + void insert(MTL::Allocation* buf); + void erase(MTL::Allocation* buf); + + void resize(size_t size); + + private: + MTL::ResidencySet* wired_set_{nullptr}; + std::unordered_set unwired_set_; + size_t capacity_{0}; +}; + +} // namespace mlx::core::metal diff --git a/dist/include/mlx/backend/metal/scan.h b/dist/include/mlx/backend/metal/scan.h new file mode 100644 index 0000000..dab79c5 --- /dev/null +++ b/dist/include/mlx/backend/metal/scan.h @@ -0,0 +1,17 @@ +#pragma once + +#include "mlx/array.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +void scan_gpu_inplace( + array in, + array& out, + Scan::ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/ternary.h b/dist/include/mlx/backend/metal/ternary.h new file mode 100644 index 0000000..91c6fbb --- /dev/null +++ b/dist/include/mlx/backend/metal/ternary.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void ternary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void ternary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/unary.h b/dist/include/mlx/backend/metal/unary.h new file mode 100644 index 0000000..1d6ecf0 --- /dev/null +++ b/dist/include/mlx/backend/metal/unary.h @@ -0,0 +1,21 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +void unary_op_gpu( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +void unary_op_gpu_inplace( + const std::vector& inputs, + array& out, + const char* op, + const Stream& s); + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/metal/utils.h b/dist/include/mlx/backend/metal/utils.h new file mode 100644 index 0000000..e7784e5 --- /dev/null +++ b/dist/include/mlx/backend/metal/utils.h @@ -0,0 +1,84 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/backend/metal/device.h" +#include "mlx/primitives.h" + +namespace mlx::core { + +std::string type_to_name(const Dtype& t); +std::string type_to_name(const array& a); + +// Compute the grid and block dimensions, check backend/common/utils.h for docs. +MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10); +MTL::Size get_2d_grid_dims(const Shape& shape, const Strides& strides); +MTL::Size +get_2d_grid_dims(const Shape& shape, const Strides& strides, size_t divisor); + +inline NS::String* make_string(std::ostringstream& os) { + std::string string = os.str(); + return NS::String::string(string.c_str(), NS::UTF8StringEncoding); +} + +inline void debug_set_stream_queue_label(MTL::CommandQueue* queue, int index) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + label << "Stream " << index; + queue->setLabel(make_string(label)); +#endif +} + +inline void debug_set_primitive_buffer_label( + MTL::CommandBuffer* command_buffer, + Primitive& primitive) { +#ifdef MLX_METAL_DEBUG + std::ostringstream label; + if (auto cbuf_label = command_buffer->label(); cbuf_label) { + label << cbuf_label->utf8String(); + } + label << primitive.name(); + command_buffer->setLabel(make_string(label)); +#endif +} + +template +constexpr bool is_numeric_except_char = std::is_arithmetic_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v && 
!std::is_same_v; + +template +void concatenate(std::string& acc, T first) { + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } +} + +template +void concatenate(std::string& acc, T first, Args... args) { + if constexpr (is_numeric_except_char) { + acc += std::to_string(first); + } else { + acc += first; + } + concatenate(acc, args...); +} + +inline int get_work_per_thread(Dtype dtype) { + return std::max(1, 8 / dtype.size()); +} +inline int get_work_per_thread(Dtype dtype, size_t size) { + constexpr size_t wpt_threshold = 1 << 16; + return size < wpt_threshold ? 1 : std::max(1, 8 / dtype.size()); +} + +inline size_t ceildiv(size_t n, size_t m) { + return (n + m - 1) / m; +} + +} // namespace mlx::core diff --git a/dist/include/mlx/backend/no_gpu/apple_memory.h b/dist/include/mlx/backend/no_gpu/apple_memory.h new file mode 100644 index 0000000..7fdc530 --- /dev/null +++ b/dist/include/mlx/backend/no_gpu/apple_memory.h @@ -0,0 +1,16 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + size_t memsize = 0; + size_t length = sizeof(memsize); + sysctlbyname("hw.memsize", &memsize, &length, NULL, 0); + return memsize; +} + +} // namespace diff --git a/dist/include/mlx/backend/no_gpu/linux_memory.h b/dist/include/mlx/backend/no_gpu/linux_memory.h new file mode 100644 index 0000000..f909edc --- /dev/null +++ b/dist/include/mlx/backend/no_gpu/linux_memory.h @@ -0,0 +1,22 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace { + +size_t get_memory_size() { + struct sysinfo info; + + if (sysinfo(&info) != 0) { + return 0; + } + + size_t total_ram = info.totalram; + total_ram *= info.mem_unit; + + return total_ram; +} + +} // namespace diff --git a/dist/include/mlx/c/array.h b/dist/include/mlx/c/array.h new file mode 100644 index 0000000..fce8841 --- /dev/null +++ b/dist/include/mlx/c/array.h @@ -0,0 +1,379 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ + +#ifndef MLX_ARRAY_H +#define MLX_ARRAY_H + +#include "mlx/c/string.h" + +#include +#include +#include +#include + +#include "half.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_array Array + * MLX N-dimensional array object. + */ +/**@{*/ + +/** + * A N-dimensional array object. + */ +typedef struct mlx_array_ { + void* ctx; +} mlx_array; + +static mlx_array mlx_array_empty; + +/** + * Array element type. + */ +typedef enum mlx_dtype_ { + MLX_BOOL, + MLX_UINT8, + MLX_UINT16, + MLX_UINT32, + MLX_UINT64, + MLX_INT8, + MLX_INT16, + MLX_INT32, + MLX_INT64, + MLX_FLOAT16, + MLX_FLOAT32, + MLX_FLOAT64, + MLX_BFLOAT16, + MLX_COMPLEX64, +} mlx_dtype; + +/** + * Size of given mlx_dtype datatype in bytes. + */ +size_t mlx_dtype_size(mlx_dtype dtype); + +/** + * Get array description. + */ +int mlx_array_tostring(mlx_string* str, const mlx_array arr); + +/** + * New empty array. + */ +mlx_array mlx_array_new(void); + +/** + * Free an array. + */ +int mlx_array_free(mlx_array arr); + +/** + * New array from a bool scalar. + */ +mlx_array mlx_array_new_bool(bool val); +/** + * New array from a int scalar. + */ +mlx_array mlx_array_new_int(int val); +/** + * New array from a float32 scalar. + */ +mlx_array mlx_array_new_float32(float val); +/** + * New array from a float scalar. + * Same as float32. + */ +mlx_array mlx_array_new_float(float val); +/** + * New array from a float64 scalar. + */ +mlx_array mlx_array_new_float64(double val); +/** + * New array from a double scalar. + * Same as float64. + */ +mlx_array mlx_array_new_double(double val); +/** + * New array from a complex scalar. + */ +mlx_array mlx_array_new_complex(float real_val, float imag_val); +/** + * New array from existing buffer. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. 
+ */ +mlx_array mlx_array_new_data( + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); +/** + * Set array to provided src array. + */ +int mlx_array_set(mlx_array* arr, const mlx_array src); +/** + * Set array to a bool scalar. + */ +int mlx_array_set_bool(mlx_array* arr, bool val); +/** + * Set array to a int scalar. + */ +int mlx_array_set_int(mlx_array* arr, int val); +/** + * Set array to a float32 scalar. + */ +int mlx_array_set_float32(mlx_array* arr, float val); +/** + * Set array to a float scalar. + */ +int mlx_array_set_float(mlx_array* arr, float val); +/** + * Set array to a float64 scalar. + */ +int mlx_array_set_float64(mlx_array* arr, double val); +/** + * Set array to a double scalar. + */ +int mlx_array_set_double(mlx_array* arr, double val); +/** + * Set array to a complex scalar. + */ +int mlx_array_set_complex(mlx_array* arr, float real_val, float imag_val); +/** + * Set array to specified data and shape. + * @param arr Destination array. + * @param data A buffer which will be copied. + * @param shape Shape of the array. + * @param dim Number of dimensions (size of `shape`). + * @param dtype Type of array elements. + */ +int mlx_array_set_data( + mlx_array* arr, + const void* data, + const int* shape, + int dim, + mlx_dtype dtype); + +/** + * The size of the array's datatype in bytes. + */ +size_t mlx_array_itemsize(const mlx_array arr); +/** + * Number of elements in the array. + */ +size_t mlx_array_size(const mlx_array arr); +/** + * The number of bytes in the array. + */ +size_t mlx_array_nbytes(const mlx_array arr); +/** + * The array's dimension. + */ +size_t mlx_array_ndim(const mlx_array arr); +/** + * The shape of the array. + * Returns: a pointer to the sizes of each dimension. + */ +const int* mlx_array_shape(const mlx_array arr); +/** + * The strides of the array. + * Returns: a pointer to the sizes of each dimension. 
+ */ +const size_t* mlx_array_strides(const mlx_array arr); +/** + * The shape of the array in a particular dimension. + */ +int mlx_array_dim(const mlx_array arr, int dim); +/** + * The array element type. + */ +mlx_dtype mlx_array_dtype(const mlx_array arr); + +/** + * Evaluate the array. + */ +int mlx_array_eval(mlx_array arr); + +/** + * Access the value of a scalar array. + */ +int mlx_array_item_bool(bool* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint8(uint8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint16(uint16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint32(uint32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_uint64(uint64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int8(int8_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int16(int16_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int32(int32_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_int64(int64_t* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float32(float* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float64(double* res, const mlx_array arr); +/** + * Access the value of a scalar array. + */ +int mlx_array_item_complex64(float _Complex* res, const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Access the value of a scalar array. + */ +int mlx_array_item_float16(float16_t* res, const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Access the value of a scalar array. 
+ */ +int mlx_array_item_bfloat16(bfloat16_t* res, const mlx_array arr); +#endif + +/** + * Returns a pointer to the array data, cast to `bool*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bool* mlx_array_data_bool(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint8_t* mlx_array_data_uint8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint16_t* mlx_array_data_uint16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint32_t* mlx_array_data_uint32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `uint64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const uint64_t* mlx_array_data_uint64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int8_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int8_t* mlx_array_data_int8(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int16_t* mlx_array_data_int16(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int32_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int32_t* mlx_array_data_int32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `int64_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const int64_t* mlx_array_data_int64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float32*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float* mlx_array_data_float32(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `float64*`. 
+ * Array must be evaluated, otherwise returns NULL. + */ +const double* mlx_array_data_float64(const mlx_array arr); +/** + * Returns a pointer to the array data, cast to `_Complex*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float _Complex* mlx_array_data_complex64(const mlx_array arr); + +#ifdef HAS_FLOAT16 +/** + * Returns a pointer to the array data, cast to `float16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const float16_t* mlx_array_data_float16(const mlx_array arr); +#endif + +#ifdef HAS_BFLOAT16 +/** + * Returns a pointer to the array data, cast to `bfloat16_t*`. + * Array must be evaluated, otherwise returns NULL. + */ +const bfloat16_t* mlx_array_data_bfloat16(const mlx_array arr); +#endif + +/** + * Check if the array is available. + * Internal function: use at your own risk. + */ +int _mlx_array_is_available(bool* res, const mlx_array arr); + +/** + * Wait on the array to be available. After this `_mlx_array_is_available` + * returns `true`. Internal function: use at your own risk. + */ +int _mlx_array_wait(const mlx_array arr); + +/** + * Whether the array is contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's rows are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_row_contiguous(bool* res, const mlx_array arr); + +/** + * Whether the array's columns are contiguous in memory. + * Internal function: use at your own risk. + */ +int _mlx_array_is_col_contiguous(bool* res, const mlx_array arr); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/closure.h b/dist/include/mlx/c/closure.h new file mode 100644 index 0000000..33f7115 --- /dev/null +++ b/dist/include/mlx/c/closure.h @@ -0,0 +1,197 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_CLOSURE_H +#define MLX_CLOSURE_H + +#include "mlx/c/array.h" +#include "mlx/c/map.h" +#include "mlx/c/optional.h" +#include "mlx/c/stream.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_closure Closures + * MLX closure objects. + */ +/**@{*/ + +typedef struct mlx_closure_ { + void* ctx; +} mlx_closure; +mlx_closure mlx_closure_new(void); +int mlx_closure_free(mlx_closure cls); +mlx_closure mlx_closure_new_func( + int (*fun)(mlx_vector_array*, const mlx_vector_array)); +mlx_closure mlx_closure_new_func_payload( + int (*fun)(mlx_vector_array*, const mlx_vector_array, void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_set(mlx_closure* cls, const mlx_closure src); +int mlx_closure_apply( + mlx_vector_array* res, + mlx_closure cls, + const mlx_vector_array input); + +mlx_closure mlx_closure_new_unary(int (*fun)(mlx_array*, const mlx_array)); + +typedef struct mlx_closure_kwargs_ { + void* ctx; +} mlx_closure_kwargs; +mlx_closure_kwargs mlx_closure_kwargs_new(void); +int mlx_closure_kwargs_free(mlx_closure_kwargs cls); +mlx_closure_kwargs mlx_closure_kwargs_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array)); +mlx_closure_kwargs mlx_closure_kwargs_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_map_string_to_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_kwargs_set( + mlx_closure_kwargs* cls, + const mlx_closure_kwargs src); +int mlx_closure_kwargs_apply( + mlx_vector_array* res, + mlx_closure_kwargs cls, + const mlx_vector_array input_0, + const mlx_map_string_to_array input_1); + +typedef struct mlx_closure_value_and_grad_ { + void* ctx; +} mlx_closure_value_and_grad; +mlx_closure_value_and_grad mlx_closure_value_and_grad_new(void); +int mlx_closure_value_and_grad_free(mlx_closure_value_and_grad cls); +mlx_closure_value_and_grad 
mlx_closure_value_and_grad_new_func( + int (*fun)(mlx_vector_array*, mlx_vector_array*, const mlx_vector_array)); +mlx_closure_value_and_grad mlx_closure_value_and_grad_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_array*, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_value_and_grad_set( + mlx_closure_value_and_grad* cls, + const mlx_closure_value_and_grad src); +int mlx_closure_value_and_grad_apply( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + mlx_closure_value_and_grad cls, + const mlx_vector_array input); + +typedef struct mlx_closure_custom_ { + void* ctx; +} mlx_closure_custom; +mlx_closure_custom mlx_closure_custom_new(void); +int mlx_closure_custom_free(mlx_closure_custom cls); +mlx_closure_custom mlx_closure_custom_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array)); +mlx_closure_custom mlx_closure_custom_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const mlx_vector_array, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_set( + mlx_closure_custom* cls, + const mlx_closure_custom src); +int mlx_closure_custom_apply( + mlx_vector_array* res, + mlx_closure_custom cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const mlx_vector_array input_2); + +typedef struct mlx_closure_custom_jvp_ { + void* ctx; +} mlx_closure_custom_jvp; +mlx_closure_custom_jvp mlx_closure_custom_jvp_new(void); +int mlx_closure_custom_jvp_free(mlx_closure_custom_jvp cls); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_jvp mlx_closure_custom_jvp_new_func_payload( + int (*fun)( + mlx_vector_array*, + const mlx_vector_array, + const mlx_vector_array, + const int*, + size_t _num, + 
void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_jvp_set( + mlx_closure_custom_jvp* cls, + const mlx_closure_custom_jvp src); +int mlx_closure_custom_jvp_apply( + mlx_vector_array* res, + mlx_closure_custom_jvp cls, + const mlx_vector_array input_0, + const mlx_vector_array input_1, + const int* input_2, + size_t input_2_num); + +typedef struct mlx_closure_custom_vmap_ { + void* ctx; +} mlx_closure_custom_vmap; +mlx_closure_custom_vmap mlx_closure_custom_vmap_new(void); +int mlx_closure_custom_vmap_free(mlx_closure_custom_vmap cls); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num)); +mlx_closure_custom_vmap mlx_closure_custom_vmap_new_func_payload( + int (*fun)( + mlx_vector_array*, + mlx_vector_int*, + const mlx_vector_array, + const int*, + size_t _num, + void*), + void* payload, + void (*dtor)(void*)); +int mlx_closure_custom_vmap_set( + mlx_closure_custom_vmap* cls, + const mlx_closure_custom_vmap src); +int mlx_closure_custom_vmap_apply( + mlx_vector_array* res_0, + mlx_vector_int* res_1, + mlx_closure_custom_vmap cls, + const mlx_vector_array input_0, + const int* input_1, + size_t input_1_num); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/compile.h b/dist/include/mlx/c/compile.h new file mode 100644 index 0000000..ca337b0 --- /dev/null +++ b/dist/include/mlx/c/compile.h @@ -0,0 +1,55 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_COMPILE_H +#define MLX_COMPILE_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup compile Compilation operations + */ +/**@{*/ +typedef enum mlx_compile_mode_ { + MLX_COMPILE_MODE_DISABLED, + MLX_COMPILE_MODE_NO_SIMPLIFY, + MLX_COMPILE_MODE_NO_FUSE, + MLX_COMPILE_MODE_ENABLED +} mlx_compile_mode; +int mlx_compile(mlx_closure* res, const mlx_closure fun, bool shapeless); +int mlx_detail_compile( + mlx_closure* res, + const mlx_closure fun, + uintptr_t fun_id, + bool shapeless, + const uint64_t* constants, + size_t constants_num); +int mlx_detail_compile_clear_cache(void); +int mlx_detail_compile_erase(uintptr_t fun_id); +int mlx_disable_compile(void); +int mlx_enable_compile(void); +int mlx_set_compile_mode(mlx_compile_mode mode); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/device.h b/dist/include/mlx/c/device.h new file mode 100644 index 0000000..cad67f8 --- /dev/null +++ b/dist/include/mlx/c/device.h @@ -0,0 +1,80 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_DEVICE_H +#define MLX_DEVICE_H + +#include + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_device Device + * MLX device object. + */ +/**@{*/ + +/** + * A MLX device object. + */ +typedef struct mlx_device_ { + void* ctx; +} mlx_device; + +/** + * Device type. + */ +typedef enum mlx_device_type_ { MLX_CPU, MLX_GPU } mlx_device_type; + +/** + * Returns a new empty device. + */ +mlx_device mlx_device_new(void); + +/** + * Returns a new device of specified `type`, with specified `index`. + */ +mlx_device mlx_device_new_type(mlx_device_type type, int index); +/** + * Free a device. 
+ */ +int mlx_device_free(mlx_device dev); +/** + * Set device to provided src device. + */ +int mlx_device_set(mlx_device* dev, const mlx_device src); +/** + * Get device description. + */ +int mlx_device_tostring(mlx_string* str, mlx_device dev); +/** + * Check if devices are the same. + */ +bool mlx_device_equal(mlx_device lhs, mlx_device rhs); +/** + * Returns the index of the device. + */ +int mlx_device_get_index(int* index, mlx_device dev); +/** + * Returns the type of the device. + */ +int mlx_device_get_type(mlx_device_type* type, mlx_device dev); +/** + * Returns the default MLX device. + */ +int mlx_get_default_device(mlx_device* dev); +/** + * Set the default MLX device. + */ +int mlx_set_default_device(mlx_device dev); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/distributed.h b/dist/include/mlx/c/distributed.h new file mode 100644 index 0000000..2c5733d --- /dev/null +++ b/dist/include/mlx/c/distributed.h @@ -0,0 +1,81 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_DISTRIBUTED_H +#define MLX_DISTRIBUTED_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup distributed Distributed collectives + */ +/**@{*/ +int mlx_distributed_all_gather( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream S); +int mlx_distributed_all_max( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_min( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_all_sum( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_recv_like( + mlx_array* res, + const mlx_array x, + int src, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_send( + mlx_array* res, + const mlx_array x, + int dst, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +int mlx_distributed_sum_scatter( + mlx_array* res, + const mlx_array x, + const mlx_distributed_group group /* may be null */, + const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/distributed_group.h b/dist/include/mlx/c/distributed_group.h new file mode 100644 index 0000000..3cfccc8 --- /dev/null +++ b/dist/include/mlx/c/distributed_group.h @@ -0,0 +1,58 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ + +#ifndef MLX_DISTRIBUTED_GROUP_H +#define MLX_DISTRIBUTED_GROUP_H + +#include + +#include "mlx/c/stream.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_distributed_group MLX distributed + */ +/**@{*/ + +/** + * A MLX distributed group object. + */ +typedef struct mlx_distributed_group_ { + void* ctx; +} mlx_distributed_group; + +/** + * Get the rank. + */ +int mlx_distributed_group_rank(mlx_distributed_group group); + +/** + * Get the group size. + */ +int mlx_distributed_group_size(mlx_distributed_group group); + +/** + * Split the group. + */ +mlx_distributed_group +mlx_distributed_group_split(mlx_distributed_group group, int color, int key); + +/** + * Check if distributed is available. + */ +bool mlx_distributed_is_available(void); + +/** + * Initialize distributed. + */ +mlx_distributed_group mlx_distributed_init(bool strict); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/error.h b/dist/include/mlx/c/error.h new file mode 100644 index 0000000..8c063a4 --- /dev/null +++ b/dist/include/mlx/c/error.h @@ -0,0 +1,41 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_ERROR_H +#define MLX_ERROR_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_error Error management + */ +/**@{*/ + +typedef void (*mlx_error_handler_func)(const char* msg, void* data); + +/** + * Set the error handler. + */ +void mlx_set_error_handler( + mlx_error_handler_func handler, + void* data, + void (*dtor)(void*)); + +/** + * Throw an error. + */ +void _mlx_error(const char* file, const int line, const char* fmt, ...); + +/** + * Throw an error. Macro which passes file name and line number to _mlx_error(). + */ +#define mlx_error(...) 
_mlx_error(__FILE__, __LINE__, __VA_ARGS__) + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/export.h b/dist/include/mlx/c/export.h new file mode 100644 index 0000000..52cb283 --- /dev/null +++ b/dist/include/mlx/c/export.h @@ -0,0 +1,75 @@ +/* Copyright © 2023-2025 Apple Inc. */ + +#ifndef MLX_EXPORT_H +#define MLX_EXPORT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup export Function serialization + */ +/**@{*/ +int mlx_export_function( + const char* file, + const mlx_closure fun, + const mlx_vector_array args, + bool shapeless); +int mlx_export_function_kwargs( + const char* file, + const mlx_closure_kwargs fun, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs, + bool shapeless); + +typedef struct mlx_function_exporter_ { + void* ctx; +} mlx_function_exporter; +mlx_function_exporter mlx_function_exporter_new( + const char* file, + const mlx_closure fun, + bool shapeless); +int mlx_function_exporter_free(mlx_function_exporter xfunc); +int mlx_function_exporter_apply( + const mlx_function_exporter xfunc, + const mlx_vector_array args); +int mlx_function_exporter_apply_kwargs( + const mlx_function_exporter xfunc, + const mlx_vector_array args, + const mlx_map_string_to_array kwargs); + +typedef struct mlx_imported_function_ { + void* ctx; +} mlx_imported_function; +mlx_imported_function mlx_imported_function_new(const char* file); +int mlx_imported_function_free(mlx_imported_function xfunc); +int mlx_imported_function_apply( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const mlx_vector_array args); +int mlx_imported_function_apply_kwargs( + mlx_vector_array* res, + const mlx_imported_function xfunc, + const 
mlx_vector_array args, + const mlx_map_string_to_array kwargs); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/fast.h b/dist/include/mlx/c/fast.h new file mode 100644 index 0000000..7a8aba0 --- /dev/null +++ b/dist/include/mlx/c/fast.h @@ -0,0 +1,205 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_FAST_H +#define MLX_FAST_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fast Fast custom operations + */ +/**@{*/ + +typedef struct mlx_fast_cuda_kernel_config_ { + void* ctx; +} mlx_fast_cuda_kernel_config; +mlx_fast_cuda_kernel_config mlx_fast_cuda_kernel_config_new(void); +void mlx_fast_cuda_kernel_config_free(mlx_fast_cuda_kernel_config cls); + +int mlx_fast_cuda_kernel_config_add_output_arg( + mlx_fast_cuda_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_set_grid( + mlx_fast_cuda_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_cuda_kernel_config_set_thread_group( + mlx_fast_cuda_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_cuda_kernel_config_set_init_value( + mlx_fast_cuda_kernel_config cls, + float value); +int mlx_fast_cuda_kernel_config_set_verbose( + mlx_fast_cuda_kernel_config cls, + bool verbose); +int mlx_fast_cuda_kernel_config_add_template_arg_dtype( + mlx_fast_cuda_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_cuda_kernel_config_add_template_arg_int( + mlx_fast_cuda_kernel_config cls, + const char* name, + int value); +int mlx_fast_cuda_kernel_config_add_template_arg_bool( + mlx_fast_cuda_kernel_config cls, + const char* 
name, + bool value); + +typedef struct mlx_fast_cuda_kernel_ { + void* ctx; +} mlx_fast_cuda_kernel; + +mlx_fast_cuda_kernel mlx_fast_cuda_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + int shared_memory); + +void mlx_fast_cuda_kernel_free(mlx_fast_cuda_kernel cls); + +int mlx_fast_cuda_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_cuda_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_cuda_kernel_config config, + const mlx_stream stream); + +int mlx_fast_layer_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + const mlx_array bias /* may be null */, + float eps, + const mlx_stream s); + +typedef struct mlx_fast_metal_kernel_config_ { + void* ctx; +} mlx_fast_metal_kernel_config; +mlx_fast_metal_kernel_config mlx_fast_metal_kernel_config_new(void); +void mlx_fast_metal_kernel_config_free(mlx_fast_metal_kernel_config cls); + +int mlx_fast_metal_kernel_config_add_output_arg( + mlx_fast_metal_kernel_config cls, + const int* shape, + size_t size, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_set_grid( + mlx_fast_metal_kernel_config cls, + int grid1, + int grid2, + int grid3); +int mlx_fast_metal_kernel_config_set_thread_group( + mlx_fast_metal_kernel_config cls, + int thread1, + int thread2, + int thread3); +int mlx_fast_metal_kernel_config_set_init_value( + mlx_fast_metal_kernel_config cls, + float value); +int mlx_fast_metal_kernel_config_set_verbose( + mlx_fast_metal_kernel_config cls, + bool verbose); +int mlx_fast_metal_kernel_config_add_template_arg_dtype( + mlx_fast_metal_kernel_config cls, + const char* name, + mlx_dtype dtype); +int mlx_fast_metal_kernel_config_add_template_arg_int( + mlx_fast_metal_kernel_config cls, + const char* name, + int value); +int mlx_fast_metal_kernel_config_add_template_arg_bool( + mlx_fast_metal_kernel_config cls, + const char* 
name, + bool value); + +typedef struct mlx_fast_metal_kernel_ { + void* ctx; +} mlx_fast_metal_kernel; + +mlx_fast_metal_kernel mlx_fast_metal_kernel_new( + const char* name, + const mlx_vector_string input_names, + const mlx_vector_string output_names, + const char* source, + const char* header, + bool ensure_row_contiguous, + bool atomic_outputs); + +void mlx_fast_metal_kernel_free(mlx_fast_metal_kernel cls); + +int mlx_fast_metal_kernel_apply( + mlx_vector_array* outputs, + mlx_fast_metal_kernel cls, + const mlx_vector_array inputs, + const mlx_fast_metal_kernel_config config, + const mlx_stream stream); + +int mlx_fast_rms_norm( + mlx_array* res, + const mlx_array x, + const mlx_array weight /* may be null */, + float eps, + const mlx_stream s); +int mlx_fast_rope( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + int offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_rope_dynamic( + mlx_array* res, + const mlx_array x, + int dims, + bool traditional, + mlx_optional_float base, + float scale, + const mlx_array offset, + const mlx_array freqs /* may be null */, + const mlx_stream s); +int mlx_fast_scaled_dot_product_attention( + mlx_array* res, + const mlx_array queries, + const mlx_array keys, + const mlx_array values, + float scale, + const char* mask_mode, + const mlx_array mask_arr /* may be null */, + const mlx_array sinks /* may be null */, + const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/fft.h b/dist/include/mlx/c/fft.h new file mode 100644 index 0000000..b7ef5e0 --- /dev/null +++ b/dist/include/mlx/c/fft.h @@ -0,0 +1,136 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_FFT_H +#define MLX_FFT_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup fft FFT operations + */ +/**@{*/ +int mlx_fft_fft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_fft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_fftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_ifft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_ifftshift( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_irfft2( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_irfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfft( + mlx_array* res, + const mlx_array a, + int n, + int axis, + const mlx_stream s); +int mlx_fft_rfft2( + 
mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_fft_rfftn( + mlx_array* res, + const mlx_array a, + const int* n, + size_t n_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/half.h b/dist/include/mlx/c/half.h new file mode 100644 index 0000000..958d555 --- /dev/null +++ b/dist/include/mlx/c/half.h @@ -0,0 +1,26 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_HALF_H +#define MLX_HALF_H + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(__ARM_FEATURE_FP16_SCALAR_ARITHMETIC) || defined(__aarch64__) +#define HAS_FLOAT16 +#include +typedef __fp16 float16_t; +#endif + +#if defined(__ARM_FEATURE_BF16) || defined(__aarch64__) +#define HAS_BFLOAT16 +#include +typedef __bf16 bfloat16_t; +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/io.h b/dist/include/mlx/c/io.h new file mode 100644 index 0000000..2ec53e1 --- /dev/null +++ b/dist/include/mlx/c/io.h @@ -0,0 +1,61 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_IO_H +#define MLX_IO_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup io IO operations + */ +/**@{*/ +int mlx_load_reader( + mlx_array* res, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load(mlx_array* res, const char* file, const mlx_stream s); +int mlx_load_safetensors_reader( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + mlx_io_reader in_stream, + const mlx_stream s); +int mlx_load_safetensors( + mlx_map_string_to_array* res_0, + mlx_map_string_to_string* res_1, + const char* file, + const mlx_stream s); +int mlx_save_writer(mlx_io_writer out_stream, const mlx_array a); +int mlx_save(const char* file, const mlx_array a); +int mlx_save_safetensors_writer( + mlx_io_writer in_stream, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +int mlx_save_safetensors( + const char* file, + const mlx_map_string_to_array param, + const mlx_map_string_to_string metadata); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/io_types.h b/dist/include/mlx/c/io_types.h new file mode 100644 index 0000000..88349b5 --- /dev/null +++ b/dist/include/mlx/c/io_types.h @@ -0,0 +1,104 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_IO_TYPES_H +#define MLX_IO_TYPES_H + +#include + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_io_types IO Types + * MLX IO type objects. + */ +/**@{*/ + +/** + * A MLX IO reader object. + */ +typedef struct mlx_io_reader_ { + void* ctx; +} mlx_io_reader; +/** + * A MLX IO writer object. 
+ */ +typedef struct mlx_io_writer_ { + void* ctx; +} mlx_io_writer; + +/** + * Virtual table for custom IO reader and writer objects. + */ +typedef struct mlx_io_vtable_ { + bool (*is_open)(void*); + bool (*good)(void*); + size_t (*tell)(void*); + void (*seek)(void*, int64_t off, int whence); + void (*read)(void*, char* data, size_t n); + void (*read_at_offset)(void*, char* data, size_t n, size_t off); + void (*write)(void*, const char* data, size_t n); + const char* (*label)(void*); + void (*free)(void*); +} mlx_io_vtable; + +/** + * Returns a new custom IO reader. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_reader mlx_io_reader_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO reader user descriptor. + */ +int mlx_io_reader_descriptor(void** desc_, mlx_io_reader io); + +/** + * Get IO reader description. + */ +int mlx_io_reader_tostring(mlx_string* str_, mlx_io_reader io); + +/** + * Free IO reader. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. + */ +int mlx_io_reader_free(mlx_io_reader io); + +/** + * Returns a new custom IO writer. + * `vtable` operates on user descriptor `desc`. + */ +mlx_io_writer mlx_io_writer_new(void* desc, mlx_io_vtable vtable); + +/** + * Get IO writer user descriptor. + */ +int mlx_io_writer_descriptor(void** desc_, mlx_io_writer io); + +/** + * Get IO writer description. + */ +int mlx_io_writer_tostring(mlx_string* str_, mlx_io_writer io); + +/** + * Free IO writer. + * + * Note that MLX arrays are lazily evaluated, so the underlying object may + * be not freed right away. The ``free()`` callback from ``mlx_io_vtable`` + * will be called when the underlying object is actually freed. 
+ */ +int mlx_io_writer_free(mlx_io_writer io); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/linalg.h b/dist/include/mlx/c/linalg.h new file mode 100644 index 0000000..ac0b323 --- /dev/null +++ b/dist/include/mlx/c/linalg.h @@ -0,0 +1,126 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_LINALG_H +#define MLX_LINALG_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup linalg Linear algebra operations + */ +/**@{*/ +int mlx_linalg_cholesky( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cholesky_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +int mlx_linalg_cross( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_linalg_eig( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_eigh( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_eigvals(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_eigvalsh( + mlx_array* res, + const mlx_array a, + const char* UPLO, + const mlx_stream s); +int mlx_linalg_inv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu(mlx_vector_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_lu_factor( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_norm( + mlx_array* res, + const mlx_array a, + double ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); 
+int mlx_linalg_norm_matrix( + mlx_array* res, + const mlx_array a, + const char* ord, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_norm_l2( + mlx_array* res, + const mlx_array a, + const int* axis /* may be null */, + size_t axis_num, + bool keepdims, + const mlx_stream s); +int mlx_linalg_pinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_linalg_qr( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array a, + const mlx_stream s); +int mlx_linalg_solve( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linalg_solve_triangular( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool upper, + const mlx_stream s); +int mlx_linalg_svd( + mlx_vector_array* res, + const mlx_array a, + bool compute_uv, + const mlx_stream s); +int mlx_linalg_tri_inv( + mlx_array* res, + const mlx_array a, + bool upper, + const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/map.h b/dist/include/mlx/c/map.h new file mode 100644 index 0000000..56abe84 --- /dev/null +++ b/dist/include/mlx/c/map.h @@ -0,0 +1,149 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MAP_H +#define MLX_MAP_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_map Maps + * MLX map objects. + */ +/**@{*/ + +/** + * A string-to-array map + */ +typedef struct mlx_map_string_to_array_ { + void* ctx; +} mlx_map_string_to_array; + +/** + * Returns a new empty string-to-array map. + */ +mlx_map_string_to_array mlx_map_string_to_array_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_array_set( + mlx_map_string_to_array* map, + const mlx_map_string_to_array src); +/** + * Free a string-to-array map. 
+ */ +int mlx_map_string_to_array_free(mlx_map_string_to_array map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_array_insert( + mlx_map_string_to_array map, + const char* key, + const mlx_array value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_array_get( + mlx_array* value, + const mlx_map_string_to_array map, + const char* key); + +/** + * An iterator over a string-to-array map. + */ +typedef struct mlx_map_string_to_array_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_array_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_array_iterator mlx_map_string_to_array_iterator_new( + mlx_map_string_to_array map); +/** + * Free iterator. + */ +int mlx_map_string_to_array_iterator_free(mlx_map_string_to_array_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_array_iterator_next( + const char** key, + mlx_array* value, + mlx_map_string_to_array_iterator it); + +/** + * A string-to-string map + */ +typedef struct mlx_map_string_to_string_ { + void* ctx; +} mlx_map_string_to_string; + +/** + * Returns a new empty string-to-string map. + */ +mlx_map_string_to_string mlx_map_string_to_string_new(void); +/** + * Set map to provided src map. + */ +int mlx_map_string_to_string_set( + mlx_map_string_to_string* map, + const mlx_map_string_to_string src); +/** + * Free a string-to-string map. + */ +int mlx_map_string_to_string_free(mlx_map_string_to_string map); +/** + * Insert a new `value` at the specified `key` in the map. + */ +int mlx_map_string_to_string_insert( + mlx_map_string_to_string map, + const char* key, + const char* value); +/** + * Returns the value indexed at the specified `key` in the map. + */ +int mlx_map_string_to_string_get( + const char** value, + const mlx_map_string_to_string map, + const char* key); + +/** + * An iterator over a string-to-string map. 
+ */ +typedef struct mlx_map_string_to_string_iterator_ { + void* ctx; + void* map_ctx; +} mlx_map_string_to_string_iterator; +/** + * Returns a new iterator over the given map. + */ +mlx_map_string_to_string_iterator mlx_map_string_to_string_iterator_new( + mlx_map_string_to_string map); +/** + * Free iterator. + */ +int mlx_map_string_to_string_iterator_free( + mlx_map_string_to_string_iterator it); +/** + * Increment iterator. + */ +int mlx_map_string_to_string_iterator_next( + const char** key, + const char** value, + mlx_map_string_to_string_iterator it); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/memory.h b/dist/include/mlx/c/memory.h new file mode 100644 index 0000000..03d70f9 --- /dev/null +++ b/dist/include/mlx/c/memory.h @@ -0,0 +1,45 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_MEMORY_H +#define MLX_MEMORY_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup memory Memory operations + */ +/**@{*/ +int mlx_clear_cache(void); +int mlx_get_active_memory(size_t* res); +int mlx_get_cache_memory(size_t* res); +int mlx_get_memory_limit(size_t* res); +int mlx_get_peak_memory(size_t* res); +int mlx_reset_peak_memory(void); +int mlx_set_cache_limit(size_t* res, size_t limit); +int mlx_set_memory_limit(size_t* res, size_t limit); +int mlx_set_wired_limit(size_t* res, size_t limit); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/metal.h b/dist/include/mlx/c/metal.h new file mode 100644 index 0000000..4589c2d --- /dev/null +++ b/dist/include/mlx/c/metal.h @@ -0,0 +1,48 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_METAL_H +#define MLX_METAL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup metal Metal specific operations + */ +/**@{*/ + +typedef struct mlx_metal_device_info_t_ { + char architecture[256]; + size_t max_buffer_length; + size_t max_recommended_working_set_size; + size_t memory_size; +} mlx_metal_device_info_t; +mlx_metal_device_info_t mlx_metal_device_info(void); + +int mlx_metal_is_available(bool* res); +int mlx_metal_start_capture(const char* path); +int mlx_metal_stop_capture(void); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/mlx.h b/dist/include/mlx/c/mlx.h new file mode 100644 index 0000000..b62ea3b --- /dev/null +++ b/dist/include/mlx/c/mlx.h @@ -0,0 +1,33 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ + +#ifndef MLX_ALL_H +#define MLX_ALL_H + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/compile.h" +#include "mlx/c/device.h" +#include "mlx/c/distributed.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/error.h" +#include "mlx/c/export.h" +#include "mlx/c/fast.h" +#include "mlx/c/fft.h" +#include "mlx/c/half.h" +#include "mlx/c/io.h" +#include "mlx/c/io_types.h" +#include "mlx/c/linalg.h" +#include "mlx/c/map.h" +#include "mlx/c/memory.h" +#include "mlx/c/metal.h" +#include "mlx/c/ops.h" +#include "mlx/c/optional.h" +#include "mlx/c/random.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/transforms.h" +#include "mlx/c/transforms_impl.h" +#include "mlx/c/vector.h" +#include "mlx/c/version.h" + +#endif diff --git a/dist/include/mlx/c/ops.h b/dist/include/mlx/c/ops.h new file mode 100644 index 0000000..24c03c8 --- /dev/null +++ b/dist/include/mlx/c/ops.h @@ -0,0 +1,1233 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_OPS_H +#define MLX_OPS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup ops Core array operations + */ +/**@{*/ +int mlx_abs(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_add( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_addmm( + mlx_array* res, + const mlx_array c, + const mlx_array a, + const mlx_array b, + float alpha, + float beta, + const mlx_stream s); +int mlx_all_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_all_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_all( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_allclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_any_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_any_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_any( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_arange( + mlx_array* res, + double start, + double stop, + double step, + mlx_dtype dtype, + const mlx_stream s); +int mlx_arccos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arccosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arcsinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_arctan(mlx_array* 
res, const mlx_array a, const mlx_stream s); +int mlx_arctan2( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_arctanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_argmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmax( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argmin_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_argmin( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_argpartition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_argpartition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_argsort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_argsort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_array_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + bool equal_nan, + const mlx_stream s); +int mlx_as_strided( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const int64_t* strides, + size_t strides_num, + size_t offset, + const mlx_stream s); +int mlx_astype( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_atleast_1d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_2d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_atleast_3d(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_invert(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_bitwise_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_bitwise_xor( + mlx_array* res, + const mlx_array 
a, + const mlx_array b, + const mlx_stream s); +int mlx_block_masked_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int block_size, + const mlx_array mask_out /* may be null */, + const mlx_array mask_lhs /* may be null */, + const mlx_array mask_rhs /* may be null */, + const mlx_stream s); +int mlx_broadcast_arrays( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_stream s); +int mlx_broadcast_to( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_ceil(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_clip( + mlx_array* res, + const mlx_array a, + const mlx_array a_min /* may be null */, + const mlx_array a_max /* may be null */, + const mlx_stream s); +int mlx_concatenate_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_concatenate( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_conjugate(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_contiguous( + mlx_array* res, + const mlx_array a, + bool allow_col_major, + const mlx_stream s); +int mlx_conv1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int groups, + const mlx_stream s); +int mlx_conv2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int groups, + const mlx_stream s); +int mlx_conv3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int groups, + const mlx_stream s); +int mlx_conv_general( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + const int* stride, + size_t stride_num, + const int* padding_lo, + 
size_t padding_lo_num, + const int* padding_hi, + size_t padding_hi_num, + const int* kernel_dilation, + size_t kernel_dilation_num, + const int* input_dilation, + size_t input_dilation_num, + int groups, + bool flip, + const mlx_stream s); +int mlx_conv_transpose1d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride, + int padding, + int dilation, + int output_padding, + int groups, + const mlx_stream s); +int mlx_conv_transpose2d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int padding_0, + int padding_1, + int dilation_0, + int dilation_1, + int output_padding_0, + int output_padding_1, + int groups, + const mlx_stream s); +int mlx_conv_transpose3d( + mlx_array* res, + const mlx_array input, + const mlx_array weight, + int stride_0, + int stride_1, + int stride_2, + int padding_0, + int padding_1, + int padding_2, + int dilation_0, + int dilation_1, + int dilation_2, + int output_padding_0, + int output_padding_1, + int output_padding_2, + int groups, + const mlx_stream s); +int mlx_copy(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cos(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cosh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_cummax( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cummin( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumprod( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_cumsum( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_degrees(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_depends( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array dependencies); +int 
mlx_dequantize( + mlx_array* res, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + mlx_optional_dtype dtype, + const mlx_stream s); +int mlx_diag(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_diagonal( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + const mlx_stream s); +int mlx_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_divmod( + mlx_vector_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_einsum( + mlx_array* res, + const char* subscripts, + const mlx_vector_array operands, + const mlx_stream s); +int mlx_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_erf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_erfinv(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_exp(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_expand_dims_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_expand_dims( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_expm1(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_eye( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype dtype, + const mlx_stream s); +int mlx_flatten( + mlx_array* res, + const mlx_array a, + int start_axis, + int end_axis, + const mlx_stream s); +int mlx_floor(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_floor_divide( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_from_fp8( + mlx_array* res, + const mlx_array x, + mlx_dtype dtype, + const mlx_stream s); +int mlx_full( + mlx_array* res, + const int* shape, + size_t shape_num, + const mlx_array vals, + mlx_dtype dtype, + const 
mlx_stream s); +int mlx_full_like( + mlx_array* res, + const mlx_array a, + const mlx_array vals, + mlx_dtype dtype, + const mlx_stream s); +int mlx_gather( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const int* axes, + size_t axes_num, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const int* slice_sizes, + size_t slice_sizes_num, + const mlx_stream s); +int mlx_gather_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool sorted_indices, + const mlx_stream s); +int mlx_gather_qmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + const mlx_array lhs_indices /* may be null */, + const mlx_array rhs_indices /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + bool sorted_indices, + const mlx_stream s); +int mlx_greater( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_greater_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_hadamard_transform( + mlx_array* res, + const mlx_array a, + mlx_optional_float scale, + const mlx_stream s); +int mlx_identity(mlx_array* res, int n, mlx_dtype dtype, const mlx_stream s); +int mlx_imag(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_inner( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_isclose( + mlx_array* res, + const mlx_array a, + const mlx_array b, + double rtol, + double atol, + bool equal_nan, + const mlx_stream s); +int mlx_isfinite(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isinf(mlx_array* res, const mlx_array a, const mlx_stream 
s); +int mlx_isnan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isneginf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_isposinf(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_kron( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_left_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_less_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_linspace( + mlx_array* res, + double start, + double stop, + int num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_log(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log10(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log1p(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_log2(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logaddexp( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logcumsumexp( + mlx_array* res, + const mlx_array a, + int axis, + bool reverse, + bool inclusive, + const mlx_stream s); +int mlx_logical_and( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logical_not(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_logical_or( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_logsumexp_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_logsumexp( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_masked_scatter( + mlx_array* res, + const mlx_array a, + const mlx_array mask, + 
const mlx_array src, + const mlx_stream s); +int mlx_matmul( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_max_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_max_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_max( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_maximum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_mean_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_mean_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_mean( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_median( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_meshgrid( + mlx_vector_array* res, + const mlx_vector_array arrays, + bool sparse, + const char* indexing, + const mlx_stream s); +int mlx_min_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_min_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_min( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_minimum( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_moveaxis( + mlx_array* res, + const mlx_array a, + int source, + int destination, + const mlx_stream s); +int mlx_multiply( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_nan_to_num( + mlx_array* res, + const mlx_array a, + float nan, + mlx_optional_float posinf, + mlx_optional_float neginf, + const 
mlx_stream s); +int mlx_negative(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_not_equal( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_number_of_elements( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool inverted, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int mlx_ones_like(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_outer( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_pad( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const int* low_pad_size, + size_t low_pad_size_num, + const int* high_pad_size, + size_t high_pad_size_num, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_pad_symmetric( + mlx_array* res, + const mlx_array a, + int pad_width, + const mlx_array pad_value, + const char* mode, + const mlx_stream s); +int mlx_partition_axis( + mlx_array* res, + const mlx_array a, + int kth, + int axis, + const mlx_stream s); +int mlx_partition( + mlx_array* res, + const mlx_array a, + int kth, + const mlx_stream s); +int mlx_power( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_prod_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_prod_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_prod( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_put_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_qqmm( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array w_scales /* may be null */, + 
mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantize( + mlx_vector_array* res, + const mlx_array w, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_quantized_matmul( + mlx_array* res, + const mlx_array x, + const mlx_array w, + const mlx_array scales, + const mlx_array biases /* may be null */, + bool transpose, + mlx_optional_int group_size, + mlx_optional_int bits, + const char* mode, + const mlx_stream s); +int mlx_radians(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_real(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_reciprocal(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_remainder( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_repeat_axis( + mlx_array* res, + const mlx_array arr, + int repeats, + int axis, + const mlx_stream s); +int mlx_repeat( + mlx_array* res, + const mlx_array arr, + int repeats, + const mlx_stream s); +int mlx_reshape( + mlx_array* res, + const mlx_array a, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_right_shift( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_roll_axis( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + int axis, + const mlx_stream s); +int mlx_roll_axes( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_roll( + mlx_array* res, + const mlx_array a, + const int* shift, + size_t shift_num, + const mlx_stream s); +int mlx_round( + mlx_array* res, + const mlx_array a, + int decimals, + const mlx_stream s); +int mlx_rsqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_scatter( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* 
axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_add_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_add_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array values, + int axis, + const mlx_stream s); +int mlx_scatter_max( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_max_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_min( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_min_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_scatter_prod( + mlx_array* res, + const mlx_array a, + const mlx_vector_array indices, + const mlx_array updates, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_scatter_prod_single( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_array updates, + int axis, + const mlx_stream s); +int mlx_segmented_mm( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_array segments, + const mlx_stream s); +int mlx_sigmoid(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sign(mlx_array* res, const mlx_array a, const mlx_stream s); +int 
mlx_sin(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_sinh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_slice( + mlx_array* res, + const mlx_array a, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_dynamic( + mlx_array* res, + const mlx_array a, + const mlx_array start, + const int* axes, + size_t axes_num, + const int* slice_size, + size_t slice_size_num, + const mlx_stream s); +int mlx_slice_update( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const int* start, + size_t start_num, + const int* stop, + size_t stop_num, + const int* strides, + size_t strides_num, + const mlx_stream s); +int mlx_slice_update_dynamic( + mlx_array* res, + const mlx_array src, + const mlx_array update, + const mlx_array start, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_softmax_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool precise, + const mlx_stream s); +int mlx_softmax_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool precise, + const mlx_stream s); +int mlx_softmax( + mlx_array* res, + const mlx_array a, + bool precise, + const mlx_stream s); +int mlx_sort_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_sort(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_split( + mlx_vector_array* res, + const mlx_array a, + int num_splits, + int axis, + const mlx_stream s); +int mlx_split_sections( + mlx_vector_array* res, + const mlx_array a, + const int* indices, + size_t indices_num, + int axis, + const mlx_stream s); +int mlx_sqrt(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_square(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_squeeze_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int 
mlx_squeeze_axis( + mlx_array* res, + const mlx_array a, + int axis, + const mlx_stream s); +int mlx_squeeze(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_stack_axis( + mlx_array* res, + const mlx_vector_array arrays, + int axis, + const mlx_stream s); +int mlx_stack( + mlx_array* res, + const mlx_vector_array arrays, + const mlx_stream s); +int mlx_std_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_std( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_stop_gradient(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_subtract( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const mlx_stream s); +int mlx_sum_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + const mlx_stream s); +int mlx_sum_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + const mlx_stream s); +int mlx_sum( + mlx_array* res, + const mlx_array a, + bool keepdims, + const mlx_stream s); +int mlx_swapaxes( + mlx_array* res, + const mlx_array a, + int axis1, + int axis2, + const mlx_stream s); +int mlx_take_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_take( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + const mlx_stream s); +int mlx_take_along_axis( + mlx_array* res, + const mlx_array a, + const mlx_array indices, + int axis, + const mlx_stream s); +int mlx_tan(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tanh(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tensordot( + mlx_array* res, + const mlx_array a, + const mlx_array b, + const int* axes_a, + size_t axes_a_num, + const int* axes_b, + 
size_t axes_b_num, + const mlx_stream s); +int mlx_tensordot_axis( + mlx_array* res, + const mlx_array a, + const mlx_array b, + int axis, + const mlx_stream s); +int mlx_tile( + mlx_array* res, + const mlx_array arr, + const int* reps, + size_t reps_num, + const mlx_stream s); +int mlx_to_fp8(mlx_array* res, const mlx_array x, const mlx_stream s); +int mlx_topk_axis( + mlx_array* res, + const mlx_array a, + int k, + int axis, + const mlx_stream s); +int mlx_topk(mlx_array* res, const mlx_array a, int k, const mlx_stream s); +int mlx_trace( + mlx_array* res, + const mlx_array a, + int offset, + int axis1, + int axis2, + mlx_dtype dtype, + const mlx_stream s); +int mlx_transpose_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + const mlx_stream s); +int mlx_transpose(mlx_array* res, const mlx_array a, const mlx_stream s); +int mlx_tri( + mlx_array* res, + int n, + int m, + int k, + mlx_dtype type, + const mlx_stream s); +int mlx_tril(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_triu(mlx_array* res, const mlx_array x, int k, const mlx_stream s); +int mlx_unflatten( + mlx_array* res, + const mlx_array a, + int axis, + const int* shape, + size_t shape_num, + const mlx_stream s); +int mlx_var_axes( + mlx_array* res, + const mlx_array a, + const int* axes, + size_t axes_num, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var_axis( + mlx_array* res, + const mlx_array a, + int axis, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_var( + mlx_array* res, + const mlx_array a, + bool keepdims, + int ddof, + const mlx_stream s); +int mlx_view( + mlx_array* res, + const mlx_array a, + mlx_dtype dtype, + const mlx_stream s); +int mlx_where( + mlx_array* res, + const mlx_array condition, + const mlx_array x, + const mlx_array y, + const mlx_stream s); +int mlx_zeros( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_stream s); +int 
mlx_zeros_like(mlx_array* res, const mlx_array a, const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/optional.h b/dist/include/mlx/c/optional.h new file mode 100644 index 0000000..ff9ea14 --- /dev/null +++ b/dist/include/mlx/c/optional.h @@ -0,0 +1,51 @@ +/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_OPTIONAL_H +#define MLX_OPTIONAL_H + +#include + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_optional Optionals + * MLX optional scalars. + */ +/**@{*/ + +/** + * A int optional. + */ +typedef struct mlx_optional_int_ { + int value; + bool has_value; +} mlx_optional_int; + +/** + * A float optional. + */ +typedef struct mlx_optional_float_ { + float value; + bool has_value; +} mlx_optional_float; + +/** + * A dtype optional. + */ +typedef struct mlx_optional_dtype_ { + mlx_dtype value; + bool has_value; +} mlx_optional_dtype; + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/random.h b/dist/include/mlx/c/random.h new file mode 100644 index 0000000..5e9d216 --- /dev/null +++ b/dist/include/mlx/c/random.h @@ -0,0 +1,164 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_RANDOM_H +#define MLX_RANDOM_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup random Random number operations + */ +/**@{*/ +int mlx_random_bernoulli( + mlx_array* res, + const mlx_array p, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_bits( + mlx_array* res, + const int* shape, + size_t shape_num, + int width, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_shape( + mlx_array* res, + const mlx_array logits, + int axis, + const int* shape, + size_t shape_num, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical_num_samples( + mlx_array* res, + const mlx_array logits_, + int axis, + int num_samples, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_categorical( + mlx_array* res, + const mlx_array logits, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_gumbel( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_key(mlx_array* res, uint64_t seed); +int mlx_random_laplace( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_multivariate_normal( + mlx_array* res, + const mlx_array mean, + const mlx_array cov, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal_broadcast( + mlx_array* res, + const int* shape, + size_t 
shape_num, + mlx_dtype dtype, + const mlx_array loc /* may be null */, + const mlx_array scale /* may be null */, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_normal( + mlx_array* res, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + float loc, + float scale, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation( + mlx_array* res, + const mlx_array x, + int axis, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_permutation_arange( + mlx_array* res, + int x, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_randint( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_seed(uint64_t seed); +int mlx_random_split_num( + mlx_array* res, + const mlx_array key, + int num, + const mlx_stream s); +int mlx_random_split( + mlx_array* res_0, + mlx_array* res_1, + const mlx_array key, + const mlx_stream s); +int mlx_random_truncated_normal( + mlx_array* res, + const mlx_array lower, + const mlx_array upper, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +int mlx_random_uniform( + mlx_array* res, + const mlx_array low, + const mlx_array high, + const int* shape, + size_t shape_num, + mlx_dtype dtype, + const mlx_array key /* may be null */, + const mlx_stream s); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/stream.h b/dist/include/mlx/c/stream.h new file mode 100644 index 0000000..d5865b8 --- /dev/null +++ b/dist/include/mlx/c/stream.h @@ -0,0 +1,88 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ + +#ifndef MLX_STREAM_H +#define MLX_STREAM_H + +#include + +#include "mlx/c/device.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_stream Stream + * MLX stream object. + */ +/**@{*/ + +/** + * A MLX stream object. + */ +typedef struct mlx_stream_ { + void* ctx; +} mlx_stream; + +/** + * Returns a new empty stream. + */ +mlx_stream mlx_stream_new(void); + +/** + * Returns a new stream on a device. + */ +mlx_stream mlx_stream_new_device(mlx_device dev); +/** + * Set stream to provided src stream. + */ +int mlx_stream_set(mlx_stream* stream, const mlx_stream src); +/** + * Free a stream. + */ +int mlx_stream_free(mlx_stream stream); +/** + * Get stream description. + */ +int mlx_stream_tostring(mlx_string* str, mlx_stream stream); +/** + * Check if streams are the same. + */ +bool mlx_stream_equal(mlx_stream lhs, mlx_stream rhs); +/** + * Return the device of the stream. + */ +int mlx_stream_get_device(mlx_device* dev, mlx_stream stream); +/** + * Return the index of the stream. + */ +int mlx_stream_get_index(int* index, mlx_stream stream); +/** + * Synchronize with the provided stream. + */ +int mlx_synchronize(mlx_stream stream); +/** + * Returns the default stream on the given device. + */ +int mlx_get_default_stream(mlx_stream* stream, mlx_device dev); +/** + * Set default stream. + */ +int mlx_set_default_stream(mlx_stream stream); +/** + * Returns the current default CPU stream. + */ +mlx_stream mlx_default_cpu_stream_new(void); + +/** + * Returns the current default GPU stream. + */ +mlx_stream mlx_default_gpu_stream_new(void); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/string.h b/dist/include/mlx/c/string.h new file mode 100644 index 0000000..0d2a356 --- /dev/null +++ b/dist/include/mlx/c/string.h @@ -0,0 +1,55 @@ +/* Copyright © 2023-2024 Apple Inc. 
*/ + +#ifndef MLX_STRING_H +#define MLX_STRING_H + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_string String + * MLX string object. + */ +/**@{*/ + +/** + * A MLX string object. + */ +typedef struct mlx_string_ { + void* ctx; +} mlx_string; + +/** + * Returns a new empty string. + */ +mlx_string mlx_string_new(void); + +/** + * Returns a new string, copying contents from `str`, which must end with `\0`. + */ +mlx_string mlx_string_new_data(const char* str); + +/** + * Set string to src string. + */ +int mlx_string_set(mlx_string* str, const mlx_string src); + +/** + * Returns a pointer to the string contents. + * The pointer is valid for the life duration of the string. + */ +const char* mlx_string_data(mlx_string str); + +/** + * Free string. + */ +int mlx_string_free(mlx_string str); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/transforms.h b/dist/include/mlx/c/transforms.h new file mode 100644 index 0000000..c28d6e1 --- /dev/null +++ b/dist/include/mlx/c/transforms.h @@ -0,0 +1,66 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_TRANSFORMS_H +#define MLX_TRANSFORMS_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms Transform operations + */ +/**@{*/ +int mlx_async_eval(const mlx_vector_array outputs); +int mlx_checkpoint(mlx_closure* res, const mlx_closure fun); +int mlx_custom_function( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp /* may be null */, + const mlx_closure_custom_jvp fun_jvp /* may be null */, + const mlx_closure_custom_vmap fun_vmap /* may be null */); +int mlx_custom_vjp( + mlx_closure* res, + const mlx_closure fun, + const mlx_closure_custom fun_vjp); +int mlx_eval(const mlx_vector_array outputs); +int mlx_jvp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array tangents); +int mlx_value_and_grad( + mlx_closure_value_and_grad* res, + const mlx_closure fun, + const int* argnums, + size_t argnums_num); +int mlx_vjp( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array primals, + const mlx_vector_array cotangents); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/transforms_impl.h b/dist/include/mlx/c/transforms_impl.h new file mode 100644 index 0000000..78b4cfd --- /dev/null +++ b/dist/include/mlx/c/transforms_impl.h @@ -0,0 +1,52 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. 
*/ +/* */ + +#ifndef MLX_TRANSFORMS_IMPL_H +#define MLX_TRANSFORMS_IMPL_H + +#include +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/closure.h" +#include "mlx/c/distributed_group.h" +#include "mlx/c/io_types.h" +#include "mlx/c/map.h" +#include "mlx/c/stream.h" +#include "mlx/c/string.h" +#include "mlx/c/vector.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup transforms_impl Implementation detail operations + */ +/**@{*/ +int mlx_detail_vmap_replace( + mlx_vector_array* res, + const mlx_vector_array inputs, + const mlx_vector_array s_inputs, + const mlx_vector_array s_outputs, + const int* in_axes, + size_t in_axes_num, + const int* out_axes, + size_t out_axes_num); +int mlx_detail_vmap_trace( + mlx_vector_array* res_0, + mlx_vector_array* res_1, + const mlx_closure fun, + const mlx_vector_array inputs, + const int* in_axes, + size_t in_axes_num); +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/vector.h b/dist/include/mlx/c/vector.h new file mode 100644 index 0000000..81bcf74 --- /dev/null +++ b/dist/include/mlx/c/vector.h @@ -0,0 +1,133 @@ +/* Copyright © 2023-2024 Apple Inc. */ +/* */ +/* This file is auto-generated. Do not edit manually. */ +/* */ + +#ifndef MLX_VECTOR_H +#define MLX_VECTOR_H + +#include "mlx/c/array.h" +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * \defgroup mlx_vector Vectors + * MLX vector objects. + */ +/**@{*/ + +/** + * A vector of array. 
+ */ +typedef struct mlx_vector_array_ { + void* ctx; +} mlx_vector_array; +mlx_vector_array mlx_vector_array_new(void); +int mlx_vector_array_set(mlx_vector_array* vec, const mlx_vector_array src); +int mlx_vector_array_free(mlx_vector_array vec); +mlx_vector_array mlx_vector_array_new_data(const mlx_array* data, size_t size); +mlx_vector_array mlx_vector_array_new_value(const mlx_array val); +int mlx_vector_array_set_data( + mlx_vector_array* vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_set_value(mlx_vector_array* vec, const mlx_array val); +int mlx_vector_array_append_data( + mlx_vector_array vec, + const mlx_array* data, + size_t size); +int mlx_vector_array_append_value(mlx_vector_array vec, const mlx_array val); +size_t mlx_vector_array_size(mlx_vector_array vec); +int mlx_vector_array_get( + mlx_array* res, + const mlx_vector_array vec, + size_t idx); + +/** + * A vector of vector_array. + */ +typedef struct mlx_vector_vector_array_ { + void* ctx; +} mlx_vector_vector_array; +mlx_vector_vector_array mlx_vector_vector_array_new(void); +int mlx_vector_vector_array_set( + mlx_vector_vector_array* vec, + const mlx_vector_vector_array src); +int mlx_vector_vector_array_free(mlx_vector_vector_array vec); +mlx_vector_vector_array mlx_vector_vector_array_new_data( + const mlx_vector_array* data, + size_t size); +mlx_vector_vector_array mlx_vector_vector_array_new_value( + const mlx_vector_array val); +int mlx_vector_vector_array_set_data( + mlx_vector_vector_array* vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_set_value( + mlx_vector_vector_array* vec, + const mlx_vector_array val); +int mlx_vector_vector_array_append_data( + mlx_vector_vector_array vec, + const mlx_vector_array* data, + size_t size); +int mlx_vector_vector_array_append_value( + mlx_vector_vector_array vec, + const mlx_vector_array val); +size_t mlx_vector_vector_array_size(mlx_vector_vector_array vec); +int mlx_vector_vector_array_get( + 
mlx_vector_array* res, + const mlx_vector_vector_array vec, + size_t idx); + +/** + * A vector of int. + */ +typedef struct mlx_vector_int_ { + void* ctx; +} mlx_vector_int; +mlx_vector_int mlx_vector_int_new(void); +int mlx_vector_int_set(mlx_vector_int* vec, const mlx_vector_int src); +int mlx_vector_int_free(mlx_vector_int vec); +mlx_vector_int mlx_vector_int_new_data(int* data, size_t size); +mlx_vector_int mlx_vector_int_new_value(int val); +int mlx_vector_int_set_data(mlx_vector_int* vec, int* data, size_t size); +int mlx_vector_int_set_value(mlx_vector_int* vec, int val); +int mlx_vector_int_append_data(mlx_vector_int vec, int* data, size_t size); +int mlx_vector_int_append_value(mlx_vector_int vec, int val); +size_t mlx_vector_int_size(mlx_vector_int vec); +int mlx_vector_int_get(int* res, const mlx_vector_int vec, size_t idx); + +/** + * A vector of string. + */ +typedef struct mlx_vector_string_ { + void* ctx; +} mlx_vector_string; +mlx_vector_string mlx_vector_string_new(void); +int mlx_vector_string_set(mlx_vector_string* vec, const mlx_vector_string src); +int mlx_vector_string_free(mlx_vector_string vec); +mlx_vector_string mlx_vector_string_new_data(const char** data, size_t size); +mlx_vector_string mlx_vector_string_new_value(const char* val); +int mlx_vector_string_set_data( + mlx_vector_string* vec, + const char** data, + size_t size); +int mlx_vector_string_set_value(mlx_vector_string* vec, const char* val); +int mlx_vector_string_append_data( + mlx_vector_string vec, + const char** data, + size_t size); +int mlx_vector_string_append_value(mlx_vector_string vec, const char* val); +size_t mlx_vector_string_size(mlx_vector_string vec); +int mlx_vector_string_get(char** res, const mlx_vector_string vec, size_t idx); + +/**@}*/ + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/c/version.h b/dist/include/mlx/c/version.h new file mode 100644 index 0000000..96dd238 --- /dev/null +++ b/dist/include/mlx/c/version.h @@ -0,0 +1,18 @@ 
+/* Copyright © 2023-2024 Apple Inc. */ + +#ifndef MLX_VERSION_H +#define MLX_VERSION_H + +#include "mlx/c/string.h" + +#ifdef __cplusplus +extern "C" { +#endif + +int mlx_version(mlx_string* str_); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/dist/include/mlx/compile.h b/dist/include/mlx/compile.h new file mode 100644 index 0000000..a076cfb --- /dev/null +++ b/dist/include/mlx/compile.h @@ -0,0 +1,44 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include "mlx/array.h" + +namespace mlx::core { + +enum class CompileMode { disabled, no_simplify, no_fuse, enabled }; + +/** Compile takes a function and returns a compiled function. */ +std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + bool shapeless = false); + +std::function(const std::vector&)> compile( + std::vector (*fun)(const std::vector&), + bool shapeless = false); + +// Convert capture-less lambdas to function pointers. +template < + typename F, + typename = std::enable_if_t< + std::is_convertible_v())>>> +std::function(const std::vector&)> compile( + F&& f, + bool shapeless = false) { + return compile(+f, shapeless); +} + +/** Globally disable compilation. + * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also + * be used to disable compilation. + */ +void disable_compile(); + +/** Globally enable compilation. + * This will override the environment variable ``MLX_DISABLE_COMPILE``. + */ +void enable_compile(); + +/** Set the compiler mode to the given value. */ +void set_compile_mode(CompileMode mode); +} // namespace mlx::core diff --git a/dist/include/mlx/compile_impl.h b/dist/include/mlx/compile_impl.h new file mode 100644 index 0000000..ae8e26b --- /dev/null +++ b/dist/include/mlx/compile_impl.h @@ -0,0 +1,69 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/array.h" + +namespace mlx::core::detail { + +using ArraysAndExtra = std::pair, std::shared_ptr>; +using ArrayFnWithExtra = + std::function&)>; + +// This is not part of the general C++ API as calling with a bad id is a bad +// idea. +std::function(const std::vector&)> compile( + std::function(const std::vector&)> fun, + std::uintptr_t fun_id, + bool shapeless = false, + std::vector constants = {}); + +ArrayFnWithExtra compile( + ArrayFnWithExtra fun, + std::uintptr_t fun_id, + bool shapeless, + std::vector constants); + +// Erase cached compile functions +void compile_erase(std::uintptr_t fun_id); + +// Clear the compiler cache causing a recompilation of all compiled functions +// when called again. +void compile_clear_cache(); + +bool compile_available_for_device(const Device& device); + +std::tuple, std::vector, std::shared_ptr> +compile_trace( + const ArrayFnWithExtra& fun, + const std::vector& inputs, + bool shapeless); + +using ParentsMap = + std::unordered_map>>; + +// Traverses the graph to build a tape and a map of array ids to their parents +std::pair, ParentsMap> compile_dfs( + const std::vector& inputs, + std::vector& outputs, + const std::vector& original_inputs); + +// Simplify the tape. +void compile_simplify( + std::vector& tape, + ParentsMap& parents_map, + std::vector& outputs, + int passes); + +std::vector compile_replace( + const std::vector& tape, + const std::vector& trace_inputs, + const std::vector& trace_outputs, + const std::vector& inputs, + bool shapeless); + +void compile_validate_shapeless(const std::vector& tape); + +} // namespace mlx::core::detail diff --git a/dist/include/mlx/device.h b/dist/include/mlx/device.h new file mode 100644 index 0000000..80c624c --- /dev/null +++ b/dist/include/mlx/device.h @@ -0,0 +1,31 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +namespace mlx::core { + +struct Device { + enum class DeviceType { + cpu, + gpu, + }; + + static constexpr DeviceType cpu = DeviceType::cpu; + static constexpr DeviceType gpu = DeviceType::gpu; + + Device(DeviceType type, int index = 0) : type(type), index(index) {} + + DeviceType type; + int index; +}; + +const Device& default_device(); + +void set_default_device(const Device& d); + +bool operator==(const Device& lhs, const Device& rhs); +bool operator!=(const Device& lhs, const Device& rhs); + +bool is_available(const Device& d); + +} // namespace mlx::core diff --git a/dist/include/mlx/distributed/distributed.h b/dist/include/mlx/distributed/distributed.h new file mode 100644 index 0000000..a6971dd --- /dev/null +++ b/dist/include/mlx/distributed/distributed.h @@ -0,0 +1,60 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/utils.h" + +namespace mlx::core::distributed { + +// Forward declaration of the base group implementation. +namespace detail { +class GroupImpl; +}; + +/* Check if a communication backend is available */ +bool is_available(); +bool is_available(const std::string& bk); + +/** + * A distributed::Group represents a group of independent mlx processes that + * can communicate. We must also be able to create sub-groups from a group in + * order to define more granular communication. + */ +struct Group { + Group(std::shared_ptr group) : group_(std::move(group)) {} + + int rank() const; + int size() const; + + /** + * Split the group according to the provided color. Namely processes that use + * the same color will go to the same group. + * + * The key defines the rank of the processes in the new group. The smaller + * the key the smaller the rank. If the provided key is negative, then the + * rank in the current group is used. 
+ */ + Group split(int color, int key = -1) const; + + const std::shared_ptr& raw_group() const { + return group_; + } + + private: + std::shared_ptr group_{nullptr}; +}; + +/** + * Initialize the distributed backend and return the group containing all + * discoverable processes. + * + * If strict is true then throw an error if we couldn't initialize the + * distributed subsystem. Otherwise simply return a singleton group which will + * render communication operations as no-op. + */ +Group init(bool strict = false, const std::string& bk = "any"); + +} // namespace mlx::core::distributed diff --git a/dist/include/mlx/distributed/distributed_impl.h b/dist/include/mlx/distributed/distributed_impl.h new file mode 100644 index 0000000..d889587 --- /dev/null +++ b/dist/include/mlx/distributed/distributed_impl.h @@ -0,0 +1,59 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::detail { + +/** + * Abstract base class of a distributed group implementation. 
+ */ +class GroupImpl { + public: + virtual ~GroupImpl() {} + + // Choose the stream this communication group can operate on + virtual Stream communication_stream(StreamOrDevice s = {}) = 0; + + // Group operations + virtual int rank() = 0; + virtual int size() = 0; + virtual std::shared_ptr split(int color, int key = -1) = 0; + + // Actual communication operations + virtual void all_sum(const array& input, array& output, Stream stream) = 0; + virtual void all_gather(const array& input, array& output, Stream stream) = 0; + virtual void send(const array& input, int dst, Stream stream) = 0; + virtual void recv(array& out, int src, Stream stream) = 0; + virtual void all_max(const array& input, array& output, Stream stream) = 0; + virtual void all_min(const array& input, array& output, Stream stream) = 0; + virtual void + sum_scatter(const array& input, array& output, Stream stream) = 0; +}; + +/* Define the MLX stream that the communication should happen in. */ +Stream communication_stream(Group group, StreamOrDevice s = {}); + +/* Perform an all reduce sum operation */ +void all_sum(Group group, const array& input, array& output, Stream stream); + +/* Perform an all gather operation */ +void all_gather(Group group, const array& input, array& output, Stream stream); + +/** Send an array to the dst rank */ +void send(Group group, const array& input, int dst, Stream stream); + +/** Recv an array from the src rank */ +void recv(Group group, array& out, int src, Stream stream); + +/** Max reduction */ +void all_max(Group group, const array& input, array& output, Stream stream); + +/** Min reduction */ +void all_min(Group group, const array& input, array& output, Stream stream); + +/** Reduce scatter with average operation */ +void sum_scatter(Group group, const array& input, array& output, Stream stream); + +} // namespace mlx::core::distributed::detail diff --git a/dist/include/mlx/distributed/jaccl/jaccl.h b/dist/include/mlx/distributed/jaccl/jaccl.h new file mode 
100644 index 0000000..d07f9cc --- /dev/null +++ b/dist/include/mlx/distributed/jaccl/jaccl.h @@ -0,0 +1,12 @@ +// Copyright © 2025 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::jaccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::jaccl diff --git a/dist/include/mlx/distributed/mpi/mpi.h b/dist/include/mlx/distributed/mpi/mpi.h new file mode 100644 index 0000000..cd11a47 --- /dev/null +++ b/dist/include/mlx/distributed/mpi/mpi.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::mpi { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::mpi diff --git a/dist/include/mlx/distributed/mpi/mpi_declarations.h b/dist/include/mlx/distributed/mpi/mpi_declarations.h new file mode 100644 index 0000000..99c1a9c --- /dev/null +++ b/dist/include/mlx/distributed/mpi/mpi_declarations.h @@ -0,0 +1,28 @@ +// Copyright © 2024 Apple Inc. + +// Constants + +#define MPI_SUCCESS 0 +#define MPI_ANY_SOURCE -1 +#define MPI_ANY_TAG -1 +#define MPI_IN_PLACE ((void*)1) +#define MPI_MAX_LIBRARY_VERSION_STRING 256 + +// Define all the types that we use so that we don't include which +// causes linker errors on some platforms. +// +// NOTE: We define everything for openmpi. 
+ +typedef void* MPI_Comm; +typedef void* MPI_Datatype; +typedef void* MPI_Op; + +typedef void(MPI_User_function)(void*, void*, int*, MPI_Datatype*); + +typedef struct ompi_status_public_t { + int MPI_SOURCE; + int MPI_TAG; + int MPI_ERROR; + int _cancelled; + size_t _ucount; +} MPI_Status; diff --git a/dist/include/mlx/distributed/nccl/nccl.h b/dist/include/mlx/distributed/nccl/nccl.h new file mode 100644 index 0000000..5370d2d --- /dev/null +++ b/dist/include/mlx/distributed/nccl/nccl.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::nccl { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::nccl diff --git a/dist/include/mlx/distributed/ops.h b/dist/include/mlx/distributed/ops.h new file mode 100644 index 0000000..7688a5f --- /dev/null +++ b/dist/include/mlx/distributed/ops.h @@ -0,0 +1,56 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/distributed/distributed.h" +#include "mlx/utils.h" + +namespace mlx::core::distributed { + +array all_sum( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array all_gather( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice S = {}); + +array send( + const array& x, + int dst, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array recv( + Shape shape, + Dtype dtype, + int src, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array recv_like( + const array& x, + int src, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array all_max( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array all_min( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +array sum_scatter( + const array& x, + std::optional group = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::distributed diff --git a/dist/include/mlx/distributed/primitives.h b/dist/include/mlx/distributed/primitives.h new file mode 100644 index 0000000..18a0d65 --- /dev/null +++ b/dist/include/mlx/distributed/primitives.h @@ -0,0 +1,156 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/distributed_impl.h" +#include "mlx/primitives.h" + +namespace mlx::core::distributed { + +class DistPrimitive : public Primitive { + public: + DistPrimitive(Stream stream, Group group) + : Primitive(stream), group_(group) {} + + const Group& group() const { + return group_; + } + + private: + Group group_; +}; + +class AllReduce : public DistPrimitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + AllReduce(Stream stream, Group group, ReduceType reduce_type) + : DistPrimitive(stream, group), reduce_type_(reduce_type) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + const char* name() const override { + switch (reduce_type_) { + case And: + return "And AllReduce"; + case Or: + return "Or AllReduce"; + case Sum: + return "Sum AllReduce"; + case Prod: + return "Prod AllReduce"; + case Min: + return "Min AllReduce"; + case Max: + return "Max AllReduce"; + } + return ""; + } + + private: + ReduceType reduce_type_; +}; + +class AllGather : public DistPrimitive { + public: + AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const 
std::vector& argnums) override; + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(AllGather); +}; + +class Send : public DistPrimitive { + public: + Send(Stream stream, Group group, int dst) + : DistPrimitive(stream, group), dst_(dst) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + DEFINE_NAME(Send); + + private: + int dst_; +}; + +class Recv : public DistPrimitive { + public: + Recv(Stream stream, Group group, int src) + : DistPrimitive(stream, group), src_(src) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(Recv); + + private: + int src_; +}; + +class ReduceScatter : public DistPrimitive { + public: + enum ReduceType { Sum, Min, Max }; + ReduceScatter(Stream stream, Group group, ReduceType reduce_type) + : DistPrimitive(stream, group), reduce_type_(reduce_type) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "Sum ReduceScatter"; + case Min: + return "Min ReduceScatter"; + case Max: + return "Max ReduceScatter"; + } + return ""; + } + + private: + ReduceType reduce_type_; +}; +} // namespace mlx::core::distributed diff --git a/dist/include/mlx/distributed/reduction_ops.h b/dist/include/mlx/distributed/reduction_ops.h new file mode 100644 index 0000000..02777be --- /dev/null +++ b/dist/include/mlx/distributed/reduction_ops.h @@ -0,0 +1,38 @@ +// Copyright © 2025 Apple Inc. 
+ +namespace mlx::core::distributed::detail { + +template +struct SumOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output += *input; + input++; + output++; + } + } +}; + +template +struct MaxOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::max(*output, *input); + input++; + output++; + } + } +}; + +template +struct MinOp { + void operator()(const T* input, T* output, size_t N) const { + while (N-- > 0) { + *output = std::min(*output, *input); + input++; + output++; + } + } +}; + +} // namespace mlx::core::distributed::detail diff --git a/dist/include/mlx/distributed/ring/ring.h b/dist/include/mlx/distributed/ring/ring.h new file mode 100644 index 0000000..e0b3fd0 --- /dev/null +++ b/dist/include/mlx/distributed/ring/ring.h @@ -0,0 +1,12 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/distributed/distributed.h" + +namespace mlx::core::distributed::ring { + +using GroupImpl = mlx::core::distributed::detail::GroupImpl; + +bool is_available(); +std::shared_ptr init(bool strict = false); + +} // namespace mlx::core::distributed::ring diff --git a/dist/include/mlx/distributed/utils.h b/dist/include/mlx/distributed/utils.h new file mode 100644 index 0000000..213dd59 --- /dev/null +++ b/dist/include/mlx/distributed/utils.h @@ -0,0 +1,67 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include +#include + +namespace mlx::core::distributed::detail { + +struct address_t { + sockaddr_storage addr; + socklen_t len; + + const sockaddr* get() const { + return (struct sockaddr*)&addr; + } +}; + +/** + * Parse a sockaddr from an ip and port provided as strings. + */ +address_t parse_address(const std::string& ip, const std::string& port); + +/** + * Parse a sockaddr provided as an : string. + */ +address_t parse_address(const std::string& ip_port); + +/** + * Small wrapper over a TCP socket to simplify initiating connections. 
+ */ +class TCPSocket { + public: + TCPSocket(const char* tag); + TCPSocket(const TCPSocket&) = delete; + TCPSocket& operator=(const TCPSocket&) = delete; + TCPSocket(TCPSocket&& s); + TCPSocket& operator=(TCPSocket&&); + ~TCPSocket(); + + void listen(const char* tag, const address_t& addr); + TCPSocket accept(const char* tag); + + void send(const char* tag, const void* data, size_t len); + void recv(const char* tag, void* data, size_t len); + + int detach(); + + operator int() const { + return sock_; + } + + static TCPSocket connect( + const char* tag, + const address_t& addr, + int num_retries = 1, + int wait = 0, + std::function cb = nullptr); + + private: + TCPSocket(int sock); + + int sock_; +}; + +} // namespace mlx::core::distributed::detail diff --git a/dist/include/mlx/dtype.h b/dist/include/mlx/dtype.h new file mode 100644 index 0000000..e02b6ca --- /dev/null +++ b/dist/include/mlx/dtype.h @@ -0,0 +1,115 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/types/complex.h" +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct Dtype { + enum class Val { + bool_, + uint8, + uint16, + uint32, + uint64, + int8, + int16, + int32, + int64, + float16, + float32, + float64, + bfloat16, + complex64, + }; + + enum class Kind { + b, /* bool */ + u, /* unsigned int */ + i, /* signed int */ + f, /* float */ + c, /* complex */ + V, /* void - used for brain float */ + }; + + enum class Category { + complexfloating, + floating, + inexact, + signedinteger, + unsignedinteger, + integer, + number, + generic + }; + + constexpr explicit Dtype(Val val, uint8_t size) : val_(val), size_(size) {} + + constexpr operator Val() const { + return val_; + } + constexpr Val val() const { + return val_; + } + constexpr uint8_t size() const { + return size_; + } + + private: + Val val_; + uint8_t size_; +}; + +inline constexpr Dtype bool_{Dtype::Val::bool_, sizeof(bool)}; + +inline constexpr Dtype uint8{Dtype::Val::uint8, 
sizeof(uint8_t)}; +inline constexpr Dtype uint16{Dtype::Val::uint16, sizeof(uint16_t)}; +inline constexpr Dtype uint32{Dtype::Val::uint32, sizeof(uint32_t)}; +inline constexpr Dtype uint64{Dtype::Val::uint64, sizeof(uint64_t)}; + +inline constexpr Dtype int8{Dtype::Val::int8, sizeof(int8_t)}; +inline constexpr Dtype int16{Dtype::Val::int16, sizeof(int16_t)}; +inline constexpr Dtype int32{Dtype::Val::int32, sizeof(int32_t)}; +inline constexpr Dtype int64{Dtype::Val::int64, sizeof(int64_t)}; + +inline constexpr Dtype float16{Dtype::Val::float16, sizeof(uint16_t)}; +inline constexpr Dtype float32{Dtype::Val::float32, sizeof(float)}; +inline constexpr Dtype float64{Dtype::Val::float64, sizeof(double)}; +inline constexpr Dtype bfloat16{Dtype::Val::bfloat16, sizeof(uint16_t)}; +inline constexpr Dtype complex64{Dtype::Val::complex64, sizeof(complex64_t)}; + +inline constexpr Dtype::Category complexfloating = + Dtype::Category::complexfloating; +inline constexpr Dtype::Category floating = Dtype::Category::floating; +inline constexpr Dtype::Category inexact = Dtype::Category::inexact; +inline constexpr Dtype::Category signedinteger = Dtype::Category::signedinteger; +inline constexpr Dtype::Category unsignedinteger = + Dtype::Category::unsignedinteger; +inline constexpr Dtype::Category integer = Dtype::Category::integer; +inline constexpr Dtype::Category number = Dtype::Category::number; +inline constexpr Dtype::Category generic = Dtype::Category::generic; + +bool issubdtype(const Dtype& a, const Dtype& b); +bool issubdtype(const Dtype::Category& a, const Dtype& b); +bool issubdtype(const Dtype& a, const Dtype::Category& b); +bool issubdtype(const Dtype::Category& a, const Dtype::Category& b); + +Dtype promote_types(const Dtype& t1, const Dtype& t2); + +inline uint8_t size_of(const Dtype& t) { + return t.size(); +} + +Dtype::Kind kindof(const Dtype& t); + +template +struct TypeToDtype { + operator Dtype(); +}; + +} // namespace mlx::core diff --git 
a/dist/include/mlx/dtype_utils.h b/dist/include/mlx/dtype_utils.h new file mode 100644 index 0000000..47c6ed6 --- /dev/null +++ b/dist/include/mlx/dtype_utils.h @@ -0,0 +1,119 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +#include "mlx/dtype.h" +#include "mlx/utils.h" + +namespace mlx::core { + +// Return string representation of dtype. +const char* dtype_to_string(Dtype arg); + +#define MLX_INTERNAL_DTYPE_SWITCH_CASE(DTYPE, TYPE) \ + case DTYPE: \ + f(type_identity{}); \ + break + +#define MLX_INTERNAL_DTYPE_SWITCH_INTS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int8, int8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int16, int16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int32, int32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(int64, int64_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint8, uint8_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint16, uint16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint32, uint32_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(uint64, uint64_t) + +#define MLX_INTERNAL_DTYPE_SWITCH_FLOATS() \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float16, float16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(bfloat16, bfloat16_t); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float32, float); \ + MLX_INTERNAL_DTYPE_SWITCH_CASE(float64, double) + +// This already exists in C++20 but in C++20 we can also just use templated +// lambdas which will make this so much nicer. 
+template +struct type_identity { + using type = T; +}; + +#define MLX_GET_TYPE(x) typename decltype(x)::type +#define MLX_GET_VALUE(x) decltype(x)::value + +template +void dispatch_all_types(Dtype dt, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); + } +} + +template +void dispatch_int_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + default: + std::ostringstream msg; + msg << tag << " Only integer types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only float types supported but " << dt << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_inexact_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + MLX_INTERNAL_DTYPE_SWITCH_CASE(complex64, complex64_t); + default: + std::ostringstream msg; + msg << tag << " Only inexact (float/complex) types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_int_float_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << tag << " Only integer and float types supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +template +void dispatch_real_types(Dtype dt, std::string_view tag, F&& f) { + switch (dt) { + MLX_INTERNAL_DTYPE_SWITCH_CASE(bool_, bool); + MLX_INTERNAL_DTYPE_SWITCH_INTS(); + MLX_INTERNAL_DTYPE_SWITCH_FLOATS(); + default: + std::ostringstream msg; + msg << 
tag << " Only real numbers supported but " << dt + << " was provided"; + throw std::invalid_argument(msg.str()); + } +} + +} // namespace mlx::core diff --git a/dist/include/mlx/einsum.h b/dist/include/mlx/einsum.h new file mode 100644 index 0000000..f57e9a7 --- /dev/null +++ b/dist/include/mlx/einsum.h @@ -0,0 +1,22 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/array.h" +#include "mlx/utils.h" + +namespace mlx::core { + +std::pair>, std::string> einsum_path( + const std::string& subscripts, + const std::vector& operands); + +array einsum( + const std::string& subscripts, + const std::vector& operands, + StreamOrDevice s = {}); + +} // namespace mlx::core diff --git a/dist/include/mlx/event.h b/dist/include/mlx/event.h new file mode 100644 index 0000000..66a6a75 --- /dev/null +++ b/dist/include/mlx/event.h @@ -0,0 +1,58 @@ +// Copyright © 2024 Apple Inc. +#pragma once + +#include +#include +#include + +#include "mlx/stream.h" + +namespace mlx::core { + +class Event { + public: + Event() {}; + explicit Event(Stream stream); + + // Wait for the event to be signaled at its current value + void wait(); + + // Wait in the given stream for the event to be signaled at its current value + void wait(Stream stream); + + // Signal the event at its current value in the given stream + void signal(Stream stream); + + // Check if the event has been signaled at its current value + bool is_signaled() const; + + // Check if the event is valid + bool valid() const { + return event_ != nullptr; + } + + uint64_t value() const { + return value_; + } + + void set_value(uint64_t v) { + value_ = v; + } + + const Stream& stream() const { + if (!valid()) { + throw std::runtime_error( + "[Event::stream] Cannot access stream on invalid event."); + } + return stream_; + } + + private: + // Default constructed stream should never be used + // since the event is not yet valid + Stream stream_{0, Device::cpu}; + std::shared_ptr event_{nullptr}; + 
uint64_t value_{0}; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/export.h b/dist/include/mlx/export.h new file mode 100644 index 0000000..0a8e9fb --- /dev/null +++ b/dist/include/mlx/export.h @@ -0,0 +1,136 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include +#include +#include +#include "mlx/array.h" + +namespace mlx::core { + +using Args = std::vector; +using Kwargs = std::unordered_map; + +// Possible types for a Primitive's state +using StateT = std::variant< + bool, + int, + size_t, + float, + double, + Dtype, + Shape, + Strides, + std::vector, + std::vector, + std::vector>, + std::vector>, + std::optional, + std::string>; + +using ExportCallbackInput = std::unordered_map< + std::string, + std::variant< + std::vector>, + std::vector>, + std::vector>, + std::vector, + std::string>>; +using ExportCallback = std::function; + +struct FunctionExporter; + +/** + * Make an exporter to save multiple traces of a given function to + * the same file. + */ +FunctionExporter exporter( + const std::string& file, + const std::function(const Args&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const std::string& file, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const std::string& path, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function to a file. + */ +void export_function( + const std::string& file, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +void export_function( + const std::string& file, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +void export_function( + const std::string& file, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + +struct ImportedFunction; + +/** + * Import a function from a file. 
+ */ +ImportedFunction import_function(const std::string& file); + +/** + * Make an exporter to export multiple traces of a given function with the same + * callback. + */ +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + bool shapeless = false); + +FunctionExporter exporter( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + bool shapeless = false); + +/** + * Export a function with a callback. + */ +void export_function( + const ExportCallback& callback, + const std::function(const Args&)>& fun, + const Args& args, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Kwargs&)>& fun, + const Kwargs& kwargs, + bool shapeless = false); + +void export_function( + const ExportCallback& callback, + const std::function(const Args&, const Kwargs&)>& fun, + const Args& args, + const Kwargs& kwargs, + bool shapeless = false); + +} // namespace mlx::core + +#include "mlx/export_impl.h" diff --git a/dist/include/mlx/export_impl.h b/dist/include/mlx/export_impl.h new file mode 100644 index 0000000..be215aa --- /dev/null +++ b/dist/include/mlx/export_impl.h @@ -0,0 +1,98 @@ +// Copyright © 2024 Apple Inc. 
+ +#include "mlx/io/load.h" + +#pragma once + +namespace mlx::core { + +struct FunctionTable; + +struct FunctionExporter { + void operator()(const std::initializer_list& args) { + this->operator()(Args(args)); + } + void operator()(const Args& args); + void operator()(const Kwargs& kwargs); + void operator()(const Args& args, const Kwargs& kwargs); + + void close(); + + FunctionExporter(const FunctionExporter&) = delete; + FunctionExporter& operator=(const FunctionExporter&) = delete; + FunctionExporter(FunctionExporter&& other) = default; + + private: + friend FunctionExporter exporter( + const std::string&, + const std::function(const Args&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const std::string&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const std::string&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Kwargs&)>&, + bool shapeless); + + friend FunctionExporter exporter( + const ExportCallback&, + const std::function(const Args&, const Kwargs&)>&, + bool shapeless); + + FunctionExporter( + const std::string& file, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + + FunctionExporter( + const ExportCallback& callback, + std::function(const Args&, const Kwargs&)> fun, + bool shapeless); + + io::FileWriter os; + ExportCallback callback; + std::function(const Args&, const Kwargs& kwargs)> fun; + void export_function(const Args& args, const Kwargs& kwargs); + void export_with_callback( + const std::vector& inputs, + const std::vector& outputs, + const std::vector& tape, + const std::vector& kwarg_keys); + std::unordered_map constants; + int count{0}; + bool closed{false}; + std::shared_ptr ftable; +}; + +struct ImportedFunction { + 
std::vector operator()( + const std::initializer_list& args) const { + return this->operator()(Args(args)); + } + std::vector operator()(const Args& args) const; + std::vector operator()(const Kwargs& kwargs) const; + std::vector operator()(const Args& args, const Kwargs& kwargs) const; + + private: + ImportedFunction(const std::string& file); + friend ImportedFunction import_function(const std::string&); + ImportedFunction(); + + std::shared_ptr ftable; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/fast.h b/dist/include/mlx/fast.h new file mode 100644 index 0000000..0884bac --- /dev/null +++ b/dist/include/mlx/fast.h @@ -0,0 +1,102 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/utils.h" + +namespace mlx::core::fast { + +array rms_norm( + const array& x, + const std::optional& weight, + float eps, + StreamOrDevice s = {}); + +array layer_norm( + const array& x, + const std::optional& weight, + const std::optional& bias, + float eps, + StreamOrDevice s = {}); + +array rope( + const array& x, + int dims, + bool traditional, + std::optional base, + float scale, + int offset, + const std::optional& freqs = std::nullopt, + StreamOrDevice s = {}); + +array rope( + const array& x, + int dims, + bool traditional, + std::optional base, + float scale, + const array& offset, + const std::optional& freqs = std::nullopt, + StreamOrDevice s = {}); + +/** Computes: O = softmax(Q @ K.T) @ V **/ +array scaled_dot_product_attention( + const array& queries, + const array& keys, + const array& values, + const float scale, + const std::string& mask_mode = "", + std::optional mask_arr = {}, + const std::optional& sinks = {}, + StreamOrDevice s = {}); + +using TemplateArg = std::variant; +using ScalarArg = std::variant; + +using CustomKernelFunction = std::function( + const std::vector&, + const std::vector&, + const std::vector&, + std::tuple, + std::tuple, + std::vector>, + std::optional, + bool, + StreamOrDevice)>; + 
+CustomKernelFunction metal_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + bool atomic_outputs = false); + +CustomKernelFunction cuda_kernel( + const std::string& name, + const std::vector& input_names, + const std::vector& output_names, + const std::string& source, + const std::string& header = "", + bool ensure_row_contiguous = true, + int shared_memory = 0); + +std::vector precompiled_cuda_kernel( + const std::string& name, + const std::string& compiled_source, + const std::vector& inputs, + const std::vector& output_shapes, + const std::vector& output_dtypes, + const std::vector& scalars, + std::tuple grid, + std::tuple threadgroup, + int shared_memory = 0, + std::optional init_value = std::nullopt, + bool ensure_row_contiguous = false, + StreamOrDevice s = {}); + +} // namespace mlx::core::fast diff --git a/dist/include/mlx/fast_primitives.h b/dist/include/mlx/fast_primitives.h new file mode 100644 index 0000000..4434830 --- /dev/null +++ b/dist/include/mlx/fast_primitives.h @@ -0,0 +1,427 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "mlx/primitives.h" + +namespace mlx::core::fast { + +// Custom primitive accepts a fallback function which it uses for +// transformations. Transformations are virtual so that derived classes may +// override the default behavior. 
+class Custom : public Primitive { + public: + explicit Custom( + Stream stream, + std::function(std::vector)> fallback) + : Primitive(stream), fallback_(std::move(fallback)) {} + + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes) override; + + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums) override; + + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + protected: + std::function(std::vector)> fallback_; +}; + +class RMSNorm : public Custom { + public: + RMSNorm( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + static bool use_fallback(Stream stream); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(RMSNorm) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class RMSNormVJP : public Custom { + public: + RMSNormVJP( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(RMSNormVJP) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + 
+class LayerNorm : public Custom { + public: + LayerNorm( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + static bool use_fallback(Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(LayerNorm) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class LayerNormVJP : public Custom { + public: + LayerNormVJP( + Stream stream, + std::function(std::vector)> fallback, + float eps) + : Custom(stream, std::move(fallback)), eps_(eps) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(LayerNormVJP) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(nullptr, eps_); + } + + private: + float eps_; +}; + +class RoPE : public Custom { + public: + RoPE( + Stream stream, + std::function(std::vector)> fallback, + int dims, + bool traditional, + float base, + float scale, + bool forward) + : Custom(stream, std::move(fallback)), + dims_(dims), + traditional_(traditional), + base_(base), + scale_(scale), + forward_(forward) {} + + static bool use_fallback(Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& 
cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(RoPE) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, dims_, traditional_, base_, scale_, forward_); + } + + private: + int dims_; + bool traditional_; + float base_; + float scale_; + bool forward_; +}; + +class ScaledDotProductAttention : public Custom { + public: + ScaledDotProductAttention( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + bool has_sinks, + bool output_logsumexp) + : Custom(stream, std::move(fallback)), + scale_(scale), + do_causal_(do_causal), + has_sinks_(has_sinks), + output_logsumexp_(output_logsumexp) {} + + static bool use_fallback( + const array& q, + const array& k, + const array& v, + bool has_mask, + bool has_arr_mask, + bool do_causal, + bool is_training, + bool output_logsumexp, + Stream s); + static bool supports_bool_mask(); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + bool is_equivalent(const Primitive& other) const override; + + DEFINE_NAME(ScaledDotProductAttention); + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple( + nullptr, scale_, do_causal_, has_sinks_, output_logsumexp_); + } + + private: + float scale_; + bool do_causal_; + bool has_sinks_; + bool output_logsumexp_; +}; + +class ScaledDotProductAttentionVJP : public Custom { + public: + ScaledDotProductAttentionVJP( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + bool has_sinks) + : Custom(stream, std::move(fallback)), + scale_(scale), + 
do_causal_(do_causal), + has_sinks_(has_sinks) {} + + static bool use_fallback(const array& q, Stream s); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("NYI"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(ScaledDotProductAttentionVJP); + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(nullptr, scale_, do_causal_, has_sinks_); + } + + private: + float scale_; + bool do_causal_; + bool has_sinks_; +}; + +class ConvertFP8 : public Primitive { + public: + explicit ConvertFP8(Stream stream, bool to_fp8) + : Primitive(stream), to_fp8_(to_fp8) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + const char* name() const override { + if (to_fp8_) { + return "ToFP8"; + } else { + return "FromFP8"; + } + } + bool state() const { + return to_fp8_; + }; + + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE(); + + private: + bool to_fp8_; +}; + +class Quantize : public Custom { + public: + explicit Quantize( + Stream stream, + std::function(std::vector)> fallback, + int group_size, + int bits, + QuantizationMode mode, + bool dequantize) + : Custom(stream, std::move(fallback)), + group_size_(group_size), + bits_(bits), + mode_(mode), + dequantize_(dequantize) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(Quantize); + + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(nullptr, group_size_, bits_, mode_, dequantize_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool dequantize_; +}; 
+ +using ScalarArg = std::variant; + +class CustomKernel : public Primitive { + public: + CustomKernel( + Stream stream, + std::string name, + std::string source, + std::tuple grid, + std::tuple threadgroup, + std::vector> shape_infos, + bool ensure_row_contiguous, + std::optional init_value, + std::vector scalar_arguments, + bool is_precompiled, + int shared_memory) + : Primitive(stream), + name_(std::move(name)), + source_(std::move(source)), + grid_(grid), + threadgroup_(threadgroup), + shape_infos_(std::move(shape_infos)), + ensure_row_contiguous_(ensure_row_contiguous), + init_value_(init_value), + scalar_arguments_(std::move(scalar_arguments)), + is_precompiled_(is_precompiled), + shared_memory_(shared_memory) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("Custom kernels only run on GPU."); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(CustomKernel); + auto state() const { + return std::make_tuple( + name_, + source_, + grid_, + threadgroup_, + shape_infos_, + ensure_row_contiguous_, + init_value_, + scalar_arguments_, + is_precompiled_, + shared_memory_); + } + + private: + std::string name_; + std::string source_; + std::tuple grid_; + std::tuple threadgroup_; + std::vector> shape_infos_; + bool ensure_row_contiguous_; + std::optional init_value_; + std::vector scalar_arguments_; + bool is_precompiled_; + int shared_memory_; +}; + +} // namespace mlx::core::fast diff --git a/dist/include/mlx/fence.h b/dist/include/mlx/fence.h new file mode 100644 index 0000000..0ececdb --- /dev/null +++ b/dist/include/mlx/fence.h @@ -0,0 +1,39 @@ +// Copyright © 2024 Apple Inc. + +#include + +#include "mlx/array.h" + +namespace mlx::core { + +/* A fence to be used for synchronizing work between streams. + * + * Calls to `wait` wait in the given stream until all previous calls to update + * are complete on their given stream. 
+ * + * The array passed to `update` is computed and visible after the call to + * `wait` returns. The array passed to `wait` will not be read until all + * previous calls to `update` have completed. + * + * Note, calls to `update` should always be from the same thread or explicitly + * synchronized so that they occur in sequence. Calls to `wait` can be on any + * thread. + * + * For the Metal back-end the fence supports slow (default) and fast mode. + * Fast mode requires setting the environment variable + * `MLX_METAL_FAST_SYNCH=1`. Fast mode also requires Metal 3.2+ (macOS 15+, + * iOS 18+). + */ +class Fence { + public: + Fence() {}; + explicit Fence(Stream stream); + + void update(Stream stream, const array& x, bool cross_device); + void wait(Stream stream, const array& x); + + private: + std::shared_ptr fence_{nullptr}; +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/fft.h b/dist/include/mlx/fft.h new file mode 100644 index 0000000..163e06b --- /dev/null +++ b/dist/include/mlx/fft.h @@ -0,0 +1,167 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "array.h" +#include "device.h" +#include "utils.h" + +namespace mlx::core::fft { + +/** Compute the n-dimensional Fourier Transform. */ +array fftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +array fftn(const array& a, const std::vector& axes, StreamOrDevice s = {}); +array fftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse Fourier Transform. */ +array ifftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array ifftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform. 
*/ +inline array fft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return fftn(a, {n}, {axis}, s); +} +inline array fft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return fftn(a, {axis}, s); +} + +/** Compute the one-dimensional inverse Fourier Transform. */ +inline array ifft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return ifftn(a, {n}, {axis}, s); +} +inline array ifft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return ifftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform. */ +inline array fft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return fftn(a, n, axes, s); +} +inline array fft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return fftn(a, axes, s); +} + +/** Compute the two-dimensional inverse Fourier Transform. */ +inline array ifft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return ifftn(a, n, axes, s); +} +inline array ifft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return ifftn(a, axes, s); +} + +/** Compute the n-dimensional Fourier Transform on a real input. */ +array rfftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array rfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the n-dimensional inverse of `rfftn`. */ +array irfftn( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); +array irfftn(const array& a, StreamOrDevice s = {}); + +/** Compute the one-dimensional Fourier Transform on a real input. 
*/ +inline array rfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return rfftn(a, {n}, {axis}, s); +} +inline array rfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return rfftn(a, {axis}, s); +} +/** Compute the one-dimensional inverse of `rfft`. */ +inline array irfft(const array& a, int n, int axis, StreamOrDevice s = {}) { + return irfftn(a, {n}, {axis}, s); +} +inline array irfft(const array& a, int axis = -1, StreamOrDevice s = {}) { + return irfftn(a, {axis}, s); +} + +/** Compute the two-dimensional Fourier Transform on a real input. */ +inline array rfft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return rfftn(a, n, axes, s); +} +inline array rfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return rfftn(a, axes, s); +} + +/** Compute the two-dimensional inverse of `rfft2`. */ +inline array irfft2( + const array& a, + const Shape& n, + const std::vector& axes, + StreamOrDevice s = {}) { + return irfftn(a, n, axes, s); +} +inline array irfft2( + const array& a, + const std::vector& axes = {-2, -1}, + StreamOrDevice s = {}) { + return irfftn(a, axes, s); +} +/** Shift the zero-frequency component to the center of the spectrum. */ +array fftshift(const array& a, StreamOrDevice s = {}); + +/** Shift the zero-frequency component to the center of the spectrum along + * specified axes. */ +array fftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** The inverse of fftshift. */ +array ifftshift(const array& a, StreamOrDevice s = {}); + +/** The inverse of fftshift along specified axes. 
*/ +array ifftshift( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +} // namespace mlx::core::fft diff --git a/dist/include/mlx/graph_utils.h b/dist/include/mlx/graph_utils.h new file mode 100644 index 0000000..fcbeef1 --- /dev/null +++ b/dist/include/mlx/graph_utils.h @@ -0,0 +1,66 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" + +namespace mlx::core { + +struct NodeNamer { + std::unordered_map names; + + const std::string& get_name(const array& x); + void set_name(const array& x, std::string n); +}; + +void print_graph( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); + +inline void print_graph(std::ostream& os, const std::vector& outputs) { + print_graph(os, NodeNamer{}, outputs); +} + +template > +inline void print_graph(std::ostream& os, Arrays&&... outputs) { + print_graph( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} + +template > +inline void +print_graph(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { + print_graph( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); +} + +void export_to_dot( + std::ostream& os, + NodeNamer namer, + const std::vector& outputs); + +inline void export_to_dot(std::ostream& os, const std::vector& outputs) { + export_to_dot(os, NodeNamer{}, outputs); +} + +template > +inline void export_to_dot(std::ostream& os, Arrays&&... outputs) { + export_to_dot( + os, NodeNamer{}, std::vector{std::forward(outputs)...}); +} + +template > +inline void +export_to_dot(std::ostream& os, NodeNamer namer, Arrays&&... outputs) { + export_to_dot( + os, + std::move(namer), + std::vector{std::forward(outputs)...}); +} + +} // namespace mlx::core diff --git a/dist/include/mlx/io.h b/dist/include/mlx/io.h new file mode 100644 index 0000000..23380b2 --- /dev/null +++ b/dist/include/mlx/io.h @@ -0,0 +1,61 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/io/load.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core { +using GGUFMetaData = + std::variant>; +using GGUFLoad = std::pair< + std::unordered_map, + std::unordered_map>; +using SafetensorsLoad = std::pair< + std::unordered_map, + std::unordered_map>; + +/** Save array to out stream in .npy format */ +void save(std::shared_ptr out_stream, array a); + +/** Save array to file in .npy format */ +void save(std::string file, array a); + +/** Load array from reader in .npy format */ +array load(std::shared_ptr in_stream, StreamOrDevice s = {}); + +/** Load array from file in .npy format */ +array load(std::string file, StreamOrDevice s = {}); + +/** Load array map from .safetensors file format */ +SafetensorsLoad load_safetensors( + std::shared_ptr in_stream, + StreamOrDevice s = {}); +SafetensorsLoad load_safetensors( + const std::string& file, + StreamOrDevice s = {}); + +void save_safetensors( + std::shared_ptr in_stream, + std::unordered_map, + std::unordered_map metadata = {}); +void save_safetensors( + std::string file, + std::unordered_map, + std::unordered_map metadata = {}); + +/** Load array map and metadata from .gguf file format */ + +GGUFLoad load_gguf(const std::string& file, StreamOrDevice s = {}); + +void save_gguf( + std::string file, + std::unordered_map array_map, + std::unordered_map meta_data = {}); + +} // namespace mlx::core diff --git a/dist/include/mlx/io/gguf.h b/dist/include/mlx/io/gguf.h new file mode 100644 index 0000000..fa5bc45 --- /dev/null +++ b/dist/include/mlx/io/gguf.h @@ -0,0 +1,20 @@ +// Copyright © 2023-2024 Apple Inc. 
+#pragma once + +#include "mlx/io.h" +#include "mlx/primitives.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" + +extern "C" { +#include +} + +namespace mlx::core { + +Shape get_shape(const gguf_tensor& tensor); +void gguf_load_quantized( + std::unordered_map& a, + const gguf_tensor& tensor); + +} // namespace mlx::core diff --git a/dist/include/mlx/io/load.h b/dist/include/mlx/io/load.h new file mode 100644 index 0000000..0efcb36 --- /dev/null +++ b/dist/include/mlx/io/load.h @@ -0,0 +1,175 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include +#include + +#include +#ifdef _MSC_VER +#include +#else +#include +#include +#endif + +#include "mlx/threadpool.h" + +// Strictly we need to operate on files in binary mode (to avoid \r getting +// automatically inserted), but every modern system except for Windows no +// longer differentiates between binary and text files and for them define +// the flag as no-op. +#ifndef O_BINARY +#define O_BINARY 0 +#endif + +namespace mlx::core { + +namespace io { + +ThreadPool& thread_pool(); + +class Reader { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() = 0; // tellp is non-const in iostream + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void read(char* data, size_t n) = 0; + virtual void read(char* data, size_t n, size_t offset) = 0; + virtual std::string label() const = 0; + virtual ~Reader() = default; +}; + +class Writer { + public: + virtual bool is_open() const = 0; + virtual bool good() const = 0; + virtual size_t tell() = 0; + virtual void seek( + int64_t off, + std::ios_base::seekdir way = std::ios_base::beg) = 0; + virtual void write(const char* data, size_t n) = 0; + virtual std::string label() const = 0; + virtual ~Writer() = default; +}; + +class ParallelFileReader : public Reader { + public: + explicit ParallelFileReader(std::string file_path) + : fd_(open(file_path.c_str(), O_RDONLY | 
O_BINARY)), + label_(std::move(file_path)) {} + + ~ParallelFileReader() override { + close(fd_); + } + + bool is_open() const override { + return fd_ > 0; + } + + bool good() const override { + return is_open(); + } + + size_t tell() override { + return lseek(fd_, 0, SEEK_CUR); + } + + // Warning: do not use this function from multiple threads as + // it advances the file descriptor + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } + } + + // Warning: do not use this function from multiple threads as + // it advances the file descriptor + void read(char* data, size_t n) override; + + void read(char* data, size_t n, size_t offset) override; + + std::string label() const override { + return "file " + label_; + } + + private: + static constexpr size_t batch_size_ = 1 << 25; + static ThreadPool& thread_pool(); + int fd_; + std::string label_; +}; + +class FileWriter : public Writer { + public: + explicit FileWriter() {} + explicit FileWriter(std::string file_path) + : fd_(open( + file_path.c_str(), + O_CREAT | O_WRONLY | O_TRUNC | O_BINARY, + 0644)), + label_(std::move(file_path)) {} + + FileWriter(const FileWriter&) = delete; + FileWriter& operator=(const FileWriter&) = delete; + FileWriter(FileWriter&& other) { + std::swap(fd_, other.fd_); + } + + ~FileWriter() override { + if (fd_ != 0) { + close(fd_); + } + } + + bool is_open() const override { + return fd_ >= 0; + } + + bool good() const override { + return is_open(); + } + + size_t tell() override { + return lseek(fd_, 0, SEEK_CUR); + } + + void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) + override { + if (way == std::ios_base::beg) { + lseek(fd_, off, 0); + } else { + lseek(fd_, off, SEEK_CUR); + } + } + + void write(const char* data, size_t n) override { + while (n != 0) { + auto m = ::write(fd_, data, std::min(n, static_cast(INT32_MAX))); + if (m <= 
0) { + std::ostringstream msg; + msg << "[write] Unable to write " << n << " bytes to file."; + throw std::runtime_error(msg.str()); + } + data += m; + n -= m; + } + } + + std::string label() const override { + return "file " + label_; + } + + private: + int fd_{0}; + std::string label_; +}; + +} // namespace io +} // namespace mlx::core diff --git a/dist/include/mlx/linalg.h b/dist/include/mlx/linalg.h new file mode 100644 index 0000000..0690fba --- /dev/null +++ b/dist/include/mlx/linalg.h @@ -0,0 +1,111 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/ops.h" +#include "mlx/stream.h" + +namespace mlx::core::linalg { + +/** + * Compute vector or matrix norms. + * + * - If axis and ord are both unspecified, computes the 2-norm of flatten(x). + * - If axis is not provided but ord is, then x must be either 1D or 2D. + * - If axis is provided, but ord is not, then the 2-norm (or Frobenius norm + * for matrices) is computed along the given axes. At most 2 axes can be + * specified. + * - If both axis and ord are provided, then the corresponding matrix or vector + * norm is computed. At most 2 axes can be specified. 
+ */ +array norm( + const array& a, + const double ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const double ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::string& ord, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array norm( + const array& a, + const std::string& ord, + int axis, + bool keepdims = false, + StreamOrDevice s = {}) { + return norm(a, ord, std::vector{axis}, keepdims, s); +} +array norm( + const array& a, + const std::optional>& axis = std::nullopt, + bool keepdims = false, + StreamOrDevice s = {}); +inline array +norm(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {}) { + return norm(a, std::vector{axis}, keepdims, s); +} + +std::pair qr(const array& a, StreamOrDevice s = {}); + +std::vector +svd(const array& a, bool compute_uv, StreamOrDevice s /* = {} */); +inline std::vector svd(const array& a, StreamOrDevice s = {}) { + return svd(a, true, s); +} + +array inv(const array& a, StreamOrDevice s = {}); + +array tri_inv(const array& a, bool upper = false, StreamOrDevice s = {}); + +array cholesky(const array& a, bool upper = false, StreamOrDevice s = {}); + +array pinv(const array& a, StreamOrDevice s = {}); + +array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); + +std::vector lu(const array& a, StreamOrDevice s = {}); + +std::pair lu_factor(const array& a, StreamOrDevice s = {}); + +array solve(const array& a, const array& b, StreamOrDevice s = {}); + +array solve_triangular( + const array& a, + const array& b, + bool upper = false, + StreamOrDevice s = {}); + +/** + * Compute the cross product of two arrays along the given axis. 
+ */ +array cross( + const array& a, + const array& b, + int axis = -1, + StreamOrDevice s = {}); + +std::pair eig(const array& a, StreamOrDevice s = {}); + +array eigvals(const array& a, StreamOrDevice s = {}); + +array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); + +std::pair +eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {}); + +} // namespace mlx::core::linalg diff --git a/dist/include/mlx/memory.h b/dist/include/mlx/memory.h new file mode 100644 index 0000000..8a26473 --- /dev/null +++ b/dist/include/mlx/memory.h @@ -0,0 +1,78 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include + +namespace mlx::core { + +/* Get the actively used memory in bytes. + * + * Note, this will not always match memory use reported by the system because + * it does not include cached memory buffers. + * */ +size_t get_active_memory(); + +/* Get the peak amount of used memory in bytes. + * + * The maximum memory used recorded from the beginning of the program + * execution or since the last call to reset_peak_memory. + * */ +size_t get_peak_memory(); + +/* Reset the peak memory to zero. + * */ +void reset_peak_memory(); + +/* Get the cache size in bytes. + * + * The cache includes memory not currently used that has not been returned + * to the system allocator. + * */ +size_t get_cache_memory(); + +/* Set the memory limit. + * The memory limit is a guideline for the maximum amount of memory to use + * during graph evaluation. If the memory limit is exceeded and there is no + * more RAM (including swap when available) allocations will result in an + * exception. + * + * When Metal is available the memory limit defaults to 1.5 times the maximum + * recommended working set size reported by the device. + * + * Returns the previous memory limit. + * */ +size_t set_memory_limit(size_t limit); + +/* Get the current memory limit. */ +size_t get_memory_limit(); + +/* Set the cache limit. 
+ * If using more than the given limit, free memory will be reclaimed + * from the cache on the next allocation. To disable the cache, + * set the limit to 0. + * + * The cache limit defaults to the memory limit. + * + * Returns the previous cache limit. + * */ +size_t set_cache_limit(size_t limit); + +/* Clear the memory cache. */ +void clear_cache(); + +/* Set the wired size limit. + * + * Note, this function is only useful when using the Metal backend with + * macOS 15.0 or higher. + * + * The wired limit is the total size in bytes of memory that will be kept + * resident. The default value is ``0``. + * + * Setting a wired limit larger than system wired limit is an error. + * + * Returns the previous wired limit. + * */ +size_t set_wired_limit(size_t limit); + +} // namespace mlx::core diff --git a/dist/include/mlx/mlx.h b/dist/include/mlx/mlx.h new file mode 100644 index 0000000..dbc9014 --- /dev/null +++ b/dist/include/mlx/mlx.h @@ -0,0 +1,25 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/array.h" +#include "mlx/backend/cuda/cuda.h" +#include "mlx/backend/gpu/available.h" +#include "mlx/backend/metal/metal.h" +#include "mlx/compile.h" +#include "mlx/device.h" +#include "mlx/distributed/distributed.h" +#include "mlx/distributed/ops.h" +#include "mlx/einsum.h" +#include "mlx/export.h" +#include "mlx/fast.h" +#include "mlx/fft.h" +#include "mlx/io.h" +#include "mlx/linalg.h" +#include "mlx/memory.h" +#include "mlx/ops.h" +#include "mlx/random.h" +#include "mlx/stream.h" +#include "mlx/transforms.h" +#include "mlx/utils.h" +#include "mlx/version.h" diff --git a/dist/include/mlx/ops.h b/dist/include/mlx/ops.h new file mode 100644 index 0000000..ff92cbe --- /dev/null +++ b/dist/include/mlx/ops.h @@ -0,0 +1,1627 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core { + +/** + * \defgroup ops Core array operations + * @{ + */ + +/** + * A 1D array of numbers starting at `start` (optional), + * stopping at stop, stepping by `step` (optional). */ +array arange( + double start, + double stop, + double step, + Dtype dtype, + StreamOrDevice s = {}); +array arange(double start, double stop, double step, StreamOrDevice s = {}); +array arange(double start, double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double start, double stop, StreamOrDevice s = {}); +array arange(double stop, Dtype dtype, StreamOrDevice s = {}); +array arange(double stop, StreamOrDevice s = {}); + +array arange(int start, int stop, int step, StreamOrDevice s = {}); +array arange(int start, int stop, StreamOrDevice s = {}); +array arange(int stop, StreamOrDevice s = {}); + +/** A 1D array of `num` evenly spaced numbers in the range `[start, stop]` */ +array linspace( + double start, + double stop, + int num = 50, + Dtype dtype = float32, + StreamOrDevice s = {}); + +/** Convert an array to the given data type. */ +array astype(array a, Dtype dtype, StreamOrDevice s = {}); + +/** Create a view of an array with the given shape and strides. */ +array as_strided( + array a, + Shape shape, + Strides strides, + size_t offset, + StreamOrDevice s = {}); + +/** Copy another array. */ +array copy(array a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with the given value(s). 
*/ +array full(Shape shape, array vals, Dtype dtype, StreamOrDevice s = {}); +array full(Shape shape, array vals, StreamOrDevice s = {}); +template +array full(Shape shape, T val, Dtype dtype, StreamOrDevice s = {}) { + return full(std::move(shape), array(val, dtype), to_stream(s)); +} +template +array full(Shape shape, T val, StreamOrDevice s = {}) { + return full(std::move(shape), array(val), to_stream(s)); +} + +array full_like(const array& a, array vals, Dtype dtype, StreamOrDevice s = {}); +array full_like(const array& a, array vals, StreamOrDevice s = {}); +template +array full_like(const array& a, T val, Dtype dtype, StreamOrDevice s = {}) { + return full_like(a, array(val, dtype), dtype, to_stream(s)); +} +template +array full_like(const array& a, T val, StreamOrDevice s = {}) { + return full_like(a, array(val, a.dtype()), to_stream(s)); +} + +/** Fill an array of the given shape with zeros. */ +array zeros(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); +inline array zeros(const Shape& shape, StreamOrDevice s = {}) { + return zeros(shape, float32, s); +} +array zeros_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape with ones. */ +array ones(const Shape& shape, Dtype dtype, StreamOrDevice s = {}); +inline array ones(const Shape& shape, StreamOrDevice s = {}) { + return ones(shape, float32, s); +} +array ones_like(const array& a, StreamOrDevice s = {}); + +/** Fill an array of the given shape (n,m) with ones in the specified diagonal + * k, and zeros everywhere else. 
*/ +array eye(int n, int m, int k, Dtype dtype, StreamOrDevice s = {}); +inline array eye(int n, Dtype dtype, StreamOrDevice s = {}) { + return eye(n, n, 0, dtype, s); +} +inline array eye(int n, int m, StreamOrDevice s = {}) { + return eye(n, m, 0, float32, s); +} +inline array eye(int n, int m, int k, StreamOrDevice s = {}) { + return eye(n, m, k, float32, s); +} +inline array eye(int n, StreamOrDevice s = {}) { + return eye(n, n, 0, float32, s); +} + +/** Create a square matrix of shape (n,n) of zeros, and ones in the major + * diagonal. */ +array identity(int n, Dtype dtype, StreamOrDevice s = {}); +inline array identity(int n, StreamOrDevice s = {}) { + return identity(n, float32, s); +} + +array tri(int n, int m, int k, Dtype type, StreamOrDevice s = {}); +inline array tri(int n, Dtype type, StreamOrDevice s = {}) { + return tri(n, n, 0, type, s); +} + +array tril(array x, int k = 0, StreamOrDevice s = {}); +array triu(array x, int k = 0, StreamOrDevice s = {}); + +/** Reshape an array to the given shape. */ +array reshape(const array& a, Shape shape, StreamOrDevice s = {}); + +/** Unflatten the axis to the given shape. */ +array unflatten(const array& a, int axis, Shape shape, StreamOrDevice s = {}); + +/** Flatten the dimensions in the range `[start_axis, end_axis]` . */ +array flatten( + const array& a, + int start_axis, + int end_axis = -1, + StreamOrDevice s = {}); + +/** Flatten the array to 1D. */ +array flatten(const array& a, StreamOrDevice s = {}); + +/** Multiply the array by the Hadamard matrix of corresponding size. */ +array hadamard_transform( + const array& a, + std::optional scale = std::nullopt, + StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axes. */ +array squeeze( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Remove singleton dimensions at the given axis. */ +array squeeze(const array& a, int axis, StreamOrDevice s = {}); + +/** Remove all singleton dimensions. 
*/ +array squeeze(const array& a, StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axes. */ +array expand_dims( + const array& a, + const std::vector& axes, + StreamOrDevice s = {}); + +/** Add a singleton dimension at the given axis. */ +array expand_dims(const array& a, int axis, StreamOrDevice s = {}); + +/** Slice an array. */ +array slice( + const array& a, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); +inline array slice( + const array& a, + std::initializer_list start, + Shape stop, + Shape strides, + StreamOrDevice s = {}) { + return slice(a, Shape(start), std::move(stop), std::move(strides), s); +} + +/** Slice an array with a stride of 1 in each dimension. */ +array slice(const array& a, Shape start, Shape stop, StreamOrDevice s = {}); + +/** Slice an array with dynamic starting indices. */ +array slice( + const array& a, + const array& start, + std::vector axes, + Shape slice_size, + StreamOrDevice s = {}); + +/** Update a slice from the source array. */ +array slice_update( + const array& src, + const array& update, + Shape start, + Shape stop, + Shape strides, + StreamOrDevice s = {}); + +/** Update a slice from the source array with stride 1 in each dimension. */ +array slice_update( + const array& src, + const array& update, + Shape start, + Shape stop, + StreamOrDevice s = {}); + +/** Update a slice from the source array with dynamic starting indices. */ +array slice_update( + const array& src, + const array& update, + const array& start, + std::vector axes, + StreamOrDevice s = {}); + +/** Split an array into sub-arrays along a given axis. 
*/ +std::vector +split(const array& a, int num_splits, int axis, StreamOrDevice s = {}); +std::vector split(const array& a, int num_splits, StreamOrDevice s = {}); +std::vector +split(const array& a, const Shape& indices, int axis, StreamOrDevice s = {}); +std::vector +split(const array& a, const Shape& indices, StreamOrDevice s = {}); + +/** A vector of coordinate arrays from coordinate vectors. */ +std::vector meshgrid( + const std::vector& arrays, + bool sparse = false, + const std::string& indexing = "xy", + StreamOrDevice s = {}); + +/** + * Clip (limit) the values in an array. + */ +array clip( + const array& a, + const std::optional& a_min = std::nullopt, + const std::optional& a_max = std::nullopt, + StreamOrDevice s = {}); + +/** Concatenate arrays along a given axis. */ +array concatenate(std::vector arrays, int axis, StreamOrDevice s = {}); +array concatenate(std::vector arrays, StreamOrDevice s = {}); + +/** Stack arrays along a new axis. */ +array stack(const std::vector& arrays, int axis, StreamOrDevice s = {}); +array stack(const std::vector& arrays, StreamOrDevice s = {}); + +/** Repeat an array along an axis. */ +array repeat(const array& arr, int repeats, int axis, StreamOrDevice s = {}); +array repeat(const array& arr, int repeats, StreamOrDevice s = {}); + +array tile(const array& arr, std::vector reps, StreamOrDevice s = {}); + +/** Permutes the dimensions according to the given axes. */ +array transpose(const array& a, std::vector axes, StreamOrDevice s = {}); +inline array transpose( + const array& a, + std::initializer_list axes, + StreamOrDevice s = {}) { + return transpose(a, std::vector(axes), s); +} + +/** Swap two axes of an array. */ +array swapaxes(const array& a, int axis1, int axis2, StreamOrDevice s = {}); + +/** Move an axis of an array. 
*/ +array moveaxis( + const array& a, + int source, + int destination, + StreamOrDevice s = {}); + +/** Pad an array with a constant value */ +array pad( + const array& a, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); + +/** Pad an array with a constant value along all axes */ +array pad( + const array& a, + const std::vector>& pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); +array pad( + const array& a, + const std::pair& pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); +array pad( + const array& a, + int pad_width, + const array& pad_value = array(0), + const std::string& mode = "constant", + StreamOrDevice s = {}); + +/** Permutes the dimensions in reverse order. */ +array transpose(const array& a, StreamOrDevice s = {}); + +/** Broadcast an array to a given shape. */ +array broadcast_to(const array& a, const Shape& shape, StreamOrDevice s = {}); + +/** Broadcast a vector of arrays against one another. */ +std::vector broadcast_arrays( + const std::vector& inputs, + StreamOrDevice s = {}); + +/** Returns the bool array with (a == b) element-wise. */ +array equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator==(const array& a, const array& b) { + return equal(a, b); +} +template +array operator==(T a, const array& b) { + return equal(array(a), b); +} +template +array operator==(const array& a, T b) { + return equal(a, array(b)); +} + +/** Returns the bool array with (a != b) element-wise. 
*/ +array not_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator!=(const array& a, const array& b) { + return not_equal(a, b); +} +template +array operator!=(T a, const array& b) { + return not_equal(array(a), b); +} +template +array operator!=(const array& a, T b) { + return not_equal(a, array(b)); +} + +/** Returns bool array with (a > b) element-wise. */ +array greater(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>(const array& a, const array& b) { + return greater(a, b); +} +template +array operator>(T a, const array& b) { + return greater(array(a), b); +} +template +array operator>(const array& a, T b) { + return greater(a, array(b)); +} + +/** Returns bool array with (a >= b) element-wise. */ +array greater_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator>=(const array& a, const array& b) { + return greater_equal(a, b); +} +template +array operator>=(T a, const array& b) { + return greater_equal(array(a), b); +} +template +array operator>=(const array& a, T b) { + return greater_equal(a, array(b)); +} + +/** Returns bool array with (a < b) element-wise. */ +array less(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<(const array& a, const array& b) { + return less(a, b); +} +template +array operator<(T a, const array& b) { + return less(array(a), b); +} +template +array operator<(const array& a, T b) { + return less(a, array(b)); +} + +/** Returns bool array with (a <= b) element-wise. */ +array less_equal(const array& a, const array& b, StreamOrDevice s = {}); +inline array operator<=(const array& a, const array& b) { + return less_equal(a, b); +} +template +array operator<=(T a, const array& b) { + return less_equal(array(a), b); +} +template +array operator<=(const array& a, T b) { + return less_equal(a, array(b)); +} + +/** True if two arrays have the same shape and elements. 
*/ +array array_equal( + const array& a, + const array& b, + bool equal_nan, + StreamOrDevice s = {}); +inline array +array_equal(const array& a, const array& b, StreamOrDevice s = {}) { + return array_equal(a, b, false, s); +} + +array isnan(const array& a, StreamOrDevice s = {}); + +array isinf(const array& a, StreamOrDevice s = {}); + +array isfinite(const array& a, StreamOrDevice s = {}); + +array isposinf(const array& a, StreamOrDevice s = {}); + +array isneginf(const array& a, StreamOrDevice s = {}); + +/** Select from x or y depending on condition. */ +array where( + const array& condition, + const array& x, + const array& y, + StreamOrDevice s = {}); + +/** Replace NaN and infinities with finite numbers. */ +array nan_to_num( + const array& a, + float nan = 0.0f, + const std::optional posinf = std::nullopt, + const std::optional neginf = std::nullopt, + StreamOrDevice s = {}); + +/** True if all elements in the array are true (or non-zero). **/ +array all(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array all(const array& a, StreamOrDevice s = {}) { + return all(a, false, to_stream(s)); +} + +/** True if the two arrays are equal within the specified tolerance. */ +array allclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + bool equal_nan = false, + StreamOrDevice s = {}); + +/** Returns a boolean array where two arrays are element-wise equal within the + * specified tolerance. */ +array isclose( + const array& a, + const array& b, + double rtol = 1e-5, + double atol = 1e-8, + bool equal_nan = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axes. An output value is true + * if all the corresponding inputs are true. + **/ +array all( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if all the corresponding inputs are true. 
+ **/ +array all( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** True if any elements in the array are true (or non-zero). **/ +array any(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array any(const array& a, StreamOrDevice s = {}) { + return any(a, false, to_stream(s)); +} + +/** + * Reduces the input along the given axes. An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** + * Reduces the input along the given axis. An output value is true + * if any of the corresponding inputs are true. + **/ +array any( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array. */ +array sum(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array sum(const array& a, StreamOrDevice s = {}) { + return sum(a, false, to_stream(s)); +} + +/** Sums the elements of an array along the given axes. */ +array sum( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Sums the elements of an array along the given axis. */ +array sum( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array. */ +array mean(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array mean(const array& a, StreamOrDevice s = {}) { + return mean(a, false, to_stream(s)); +} + +/** Computes the mean of the elements of an array along the given axes */ +array mean( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the mean of the elements of an array along the given axis */ +array mean( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the median of the elements of an array. 
*/ +array median(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array median(const array& a, StreamOrDevice s = {}) { + return median(a, false, to_stream(s)); +} + +/** Computes the median of the elements of an array along the given axes */ +array median( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the median of the elements of an array along the given axis */ +array median( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array. */ +array var(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array var(const array& a, StreamOrDevice s = {}) { + return var(a, false, 0, to_stream(s)); +} + +/** Computes the variance of the elements of an array along the given + * axes */ +array var( + const array& a, + const std::vector& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the variance of the elements of an array along the given + * axis */ +array var( + const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the standard deviation of the elements of an array. */ +array std(const array& a, bool keepdims, int ddof = 0, StreamOrDevice s = {}); +inline array std(const array& a, StreamOrDevice s = {}) { + return std(a, false, 0, to_stream(s)); +} + +/** Computes the standard deviation of the elements of an array along the given + * axes */ +array std( + const array& a, + const std::vector& axes, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** Computes the standard deviation of the elements of an array along the given + * axis */ +array std( + const array& a, + int axis, + bool keepdims = false, + int ddof = 0, + StreamOrDevice s = {}); + +/** The product of all elements of the array. 
*/ +array prod(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array prod(const array& a, StreamOrDevice s = {}) { + return prod(a, false, to_stream(s)); +} + +/** The product of the elements of an array along the given axes. */ +array prod( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The product of the elements of an array along the given axis. */ +array prod( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of all elements of the array. */ +array max(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array max(const array& a, StreamOrDevice s = {}) { + return max(a, false, to_stream(s)); +} + +/** The maximum of the elements of an array along the given axes. */ +array max( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The maximum of the elements of an array along the given axis. */ +array max( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of all elements of the array. */ +array min(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array min(const array& a, StreamOrDevice s = {}) { + return min(a, false, to_stream(s)); +} + +/** The minimum of the elements of an array along the given axes. */ +array min( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The minimum of the elements of an array along the given axis. */ +array min( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the minimum value in the array. */ +array argmin(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmin(const array& a, StreamOrDevice s = {}) { + return argmin(a, false, s); +} + +/** Returns the indices of the minimum values along a given axis. 
*/ +array argmin( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns the index of the maximum value in the array. */ +array argmax(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array argmax(const array& a, StreamOrDevice s = {}) { + return argmax(a, false, s); +} + +/** Returns the indices of the maximum values along a given axis. */ +array argmax( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Returns a sorted copy of the flattened array. */ +array sort(const array& a, StreamOrDevice s = {}); + +/** Returns a sorted copy of the array along a given axis. */ +array sort(const array& a, int axis, StreamOrDevice s = {}); + +/** Returns indices that sort the flattened array. */ +array argsort(const array& a, StreamOrDevice s = {}); + +/** Returns indices that sort the array along a given axis. */ +array argsort(const array& a, int axis, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the flattened array + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns a partitioned copy of the array along a given axis + * such that the smaller kth elements are first. + **/ +array partition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** + * Returns indices that partition the flattened array + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, StreamOrDevice s = {}); + +/** + * Returns indices that partition the array along a given axis + * such that the smaller kth elements are first. + **/ +array argpartition(const array& a, int kth, int axis, StreamOrDevice s = {}); + +/** Returns topk elements of the flattened array. */ +array topk(const array& a, int k, StreamOrDevice s = {}); + +/** Returns topk elements of the array along a given axis. 
*/ +array topk(const array& a, int k, int axis, StreamOrDevice s = {}); + +/** Cumulative logsumexp of an array. */ +array logcumsumexp( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative logsumexp of an array along the given axis. */ +array logcumsumexp( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** The logsumexp of all elements of the array. */ +array logsumexp(const array& a, bool keepdims, StreamOrDevice s = {}); +inline array logsumexp(const array& a, StreamOrDevice s = {}) { + return logsumexp(a, false, to_stream(s)); +} + +/** The logsumexp of the elements of an array along the given axes. */ +array logsumexp( + const array& a, + const std::vector& axes, + bool keepdims = false, + StreamOrDevice s = {}); + +/** The logsumexp of the elements of an array along the given axis. */ +array logsumexp( + const array& a, + int axis, + bool keepdims = false, + StreamOrDevice s = {}); + +/** Absolute value of elements in an array. */ +array abs(const array& a, StreamOrDevice s = {}); + +/** Negate an array. */ +array negative(const array& a, StreamOrDevice s = {}); +array operator-(const array& a); + +/** The sign of the elements in an array. */ +array sign(const array& a, StreamOrDevice s = {}); + +/** Logical not of an array */ +array logical_not(const array& a, StreamOrDevice s = {}); + +/** Logical and of two arrays */ +array logical_and(const array& a, const array& b, StreamOrDevice s = {}); +array operator&&(const array& a, const array& b); + +/** Logical or of two arrays */ +array logical_or(const array& a, const array& b, StreamOrDevice s = {}); +array operator||(const array& a, const array& b); + +/** The reciprocal (1/x) of the elements in an array. */ +array reciprocal(const array& a, StreamOrDevice s = {}); + +/** Add two arrays. 
*/ +array add(const array& a, const array& b, StreamOrDevice s = {}); +array operator+(const array& a, const array& b); +template +array operator+(T a, const array& b) { + return add(array(a), b); +} +template +array operator+(const array& a, T b) { + return add(a, array(b)); +} + +/** Subtract two arrays. */ +array subtract(const array& a, const array& b, StreamOrDevice s = {}); +array operator-(const array& a, const array& b); +template +array operator-(T a, const array& b) { + return subtract(array(a), b); +} +template +array operator-(const array& a, T b) { + return subtract(a, array(b)); +} + +/** Multiply two arrays. */ +array multiply(const array& a, const array& b, StreamOrDevice s = {}); +array operator*(const array& a, const array& b); +template +array operator*(T a, const array& b) { + return multiply(array(a), b); +} +template +array operator*(const array& a, T b) { + return multiply(a, array(b)); +} + +/** Divide two arrays. */ +array divide(const array& a, const array& b, StreamOrDevice s = {}); +array operator/(const array& a, const array& b); +array operator/(double a, const array& b); +array operator/(const array& a, double b); + +/** Compute the element-wise quotient and remainder. */ +std::vector +divmod(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute integer division. Equivalent to doing floor(a / x). */ +array floor_divide(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute the element-wise remainder of division */ +array remainder(const array& a, const array& b, StreamOrDevice s = {}); +array operator%(const array& a, const array& b); +template +array operator%(T a, const array& b) { + return remainder(array(a), b); +} +template +array operator%(const array& a, T b) { + return remainder(a, array(b)); +} + +/** Element-wise maximum between two arrays. */ +array maximum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise minimum between two arrays. 
*/ +array minimum(const array& a, const array& b, StreamOrDevice s = {}); + +/** Floor the element of an array. **/ +array floor(const array& a, StreamOrDevice s = {}); + +/** Ceil the element of an array. **/ +array ceil(const array& a, StreamOrDevice s = {}); + +/** Square the elements of an array. */ +array square(const array& a, StreamOrDevice s = {}); + +/** Exponential of the elements of an array. */ +array exp(const array& a, StreamOrDevice s = {}); + +/** Sine of the elements of an array */ +array sin(const array& a, StreamOrDevice s = {}); + +/** Cosine of the elements of an array */ +array cos(const array& a, StreamOrDevice s = {}); + +/** Tangent of the elements of an array */ +array tan(const array& a, StreamOrDevice s = {}); + +/** Arc Sine of the elements of an array */ +array arcsin(const array& a, StreamOrDevice s = {}); + +/** Arc Cosine of the elements of an array */ +array arccos(const array& a, StreamOrDevice s = {}); + +/** Arc Tangent of the elements of an array */ +array arctan(const array& a, StreamOrDevice s = {}); + +/** Inverse tangent of the ratio of two arrays */ +array arctan2(const array& a, const array& b, StreamOrDevice s = {}); + +/** Hyperbolic Sine of the elements of an array */ +array sinh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Cosine of the elements of an array */ +array cosh(const array& a, StreamOrDevice s = {}); + +/** Hyperbolic Tangent of the elements of an array */ +array tanh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Sine of the elements of an array */ +array arcsinh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Cosine of the elements of an array */ +array arccosh(const array& a, StreamOrDevice s = {}); + +/** Inverse Hyperbolic Tangent of the elements of an array */ +array arctanh(const array& a, StreamOrDevice s = {}); + +/** Convert the elements of an array from Radians to Degrees **/ +array degrees(const array& a, StreamOrDevice s = {}); + +/** Convert 
the elements of an array from Degrees to Radians **/ +array radians(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of the elements of an array. */ +array log(const array& a, StreamOrDevice s = {}); + +/** Log base 2 of the elements of an array. */ +array log2(const array& a, StreamOrDevice s = {}); + +/** Log base 10 of the elements of an array. */ +array log10(const array& a, StreamOrDevice s = {}); + +/** Natural logarithm of one plus elements in the array: `log(1 + a)`. */ +array log1p(const array& a, StreamOrDevice s = {}); + +/** Log-add-exp of one elements in the array: `log(exp(a) + exp(b))`. */ +array logaddexp(const array& a, const array& b, StreamOrDevice s = {}); + +/** Element-wise logistic sigmoid of the array: `1 / (1 + exp(-x)`. */ +array sigmoid(const array& a, StreamOrDevice s = {}); + +/** Computes the error function of the elements of an array. */ +array erf(const array& a, StreamOrDevice s = {}); + +/** Computes the inverse error function of the elements of an array. */ +array erfinv(const array& a, StreamOrDevice s = {}); + +/** Computes the expm1 function of the elements of an array. */ +array expm1(const array& a, StreamOrDevice s = {}); + +/** Stop the flow of gradients. */ +array stop_gradient(const array& a, StreamOrDevice s = {}); + +/** Round a floating point number */ +array round(const array& a, int decimals, StreamOrDevice s = {}); +inline array round(const array& a, StreamOrDevice s = {}) { + return round(a, 0, s); +} + +/** Matrix-matrix multiplication. 
*/ +array matmul(const array& a, const array& b, StreamOrDevice s = {}); + +/** Gather array entries given indices and slices */ +array gather( + const array& a, + const std::vector& indices, + const std::vector& axes, + const Shape& slice_sizes, + StreamOrDevice s = {}); +inline array gather( + const array& a, + const array& indices, + int axis, + const Shape& slice_sizes, + StreamOrDevice s = {}) { + return gather(a, {indices}, std::vector{axis}, slice_sizes, s); +} + +/** Compute the Kronecker product of two arrays. */ +array kron(const array& a, const array& b, StreamOrDevice s = {}); + +/** Take array slices at the given indices of the specified axis. */ +array take( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); +array take(const array& a, int index, int axis, StreamOrDevice s = {}); + +/** Take array entries at the given indices treating the array as flattened. */ +array take(const array& a, const array& indices, StreamOrDevice s = {}); +array take(const array& a, int index, StreamOrDevice s = {}); + +/** Take array entries given indices along the axis */ +array take_along_axis( + const array& a, + const array& indices, + int axis, + StreamOrDevice s = {}); + +/** Put the values into the array at the given indices along the axis */ +array put_along_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + +/** Add the values into the array at the given indices along the axis */ +array scatter_add_axis( + const array& a, + const array& indices, + const array& values, + int axis, + StreamOrDevice s = {}); + +/** Scatter updates to the given indices. + * + * The parameters ``indices`` and ``axes`` determine the locations of ``a`` + * that are updated with the values in ``updates``. Assuming 1-d ``indices`` + * for simplicity, ``indices[i]`` are the indices on axis ``axes[i]`` to which + * the values in ``updates`` will be applied. 
Note each array in + * ``indices`` is assigned to a corresponding axis and hence ``indices.size() == + * axes.size()``. If an index/axis pair is not provided then indices along that + * axis are assumed to be zero. + * + * Note the rank of ``updates`` must be equal to the sum of the rank of the + * broadcasted ``indices`` and the rank of ``a``. In other words, assuming the + * arrays in ``indices`` have the same shape, ``updates.ndim() == + * indices[0].ndim() + a.ndim()``. The leading dimensions of ``updates`` + * correspond to the indices, and the remaining ``a.ndim()`` dimensions are the + * values that will be applied to the given location in ``a``. + * + * For example: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2}); + * auto updates = reshape(arange(1, 3, float32), {1, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * This scatters the two-element row vector ``[1, 2]`` starting at the ``(2, + * 0)`` position of ``a``. + * + * Adding another element to ``indices`` will scatter into another location of + * ``a``. 
We also have to add an another update for the new index: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = array({2, 0}); + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0}; + * + * auto out = scatter(in, {indices}, updates, axes): + * @endcode + * + * will produce: + * + * @code + * array([[3, 4, 0, 0], + * [0, 0, 0, 0], + * [1, 2, 0, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * To control the scatter location on an additional axis, add another index + * array to ``indices`` and another axis to ``axes``: + * + * @code + * auto in = zeros({4, 4}, float32); + * auto indices = std::vector{array({2, 0}), array({1, 2})}; + * auto updates = reshape(arange(1, 5, float32), {2, 1, 2}); + * std::vector axes{0, 1}; + * + * auto out = scatter(in, indices, updates, axes); + * @endcode + * + * will produce: + * + * @code + * array([[0, 0, 3, 4], + * [0, 0, 0, 0], + * [0, 1, 2, 0], + * [0, 0, 0, 0]], dtype=float32) + * @endcode + * + * Items in indices are broadcasted together. This means: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1})}; + * @endcode + * + * is equivalent to: + * + * @code + * auto indices = std::vector{array({2, 0}), array({1, 1})}; + * @endcode + * + * Note, ``scatter`` does not perform bounds checking on the indices and + * updates. Out-of-bounds accesses on ``a`` are undefined and typically result + * in unintended or invalid memory writes. 
+ */ +array scatter( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and add updates to given indices */ +array scatter_add( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_add( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_add(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and prod updates to given indices */ +array scatter_prod( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_prod( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_prod(a, {indices}, updates, std::vector{axis}, s); +} + +/** Scatter and max updates to given linear indices */ +array scatter_max( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_max( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_max(a, {indices}, updates, std::vector{axis}, s); +} +/** Scatter and min updates to given linear indices */ +array scatter_min( + const array& a, + const std::vector& indices, + const array& updates, + const std::vector& axes, + StreamOrDevice s = {}); +inline array scatter_min( + const array& a, + const array& indices, + const array& updates, + int axis, + StreamOrDevice s = {}) { + return scatter_min(a, {indices}, updates, std::vector{axis}, s); +} + +array masked_scatter( + const array& a, + const 
array& mask, + const array& src, + StreamOrDevice s = {}); + +/** Square root the elements of an array. */ +array sqrt(const array& a, StreamOrDevice s = {}); + +/** Square root and reciprocal the elements of an array. */ +array rsqrt(const array& a, StreamOrDevice s = {}); + +/** Softmax of an array. */ +array softmax( + const array& a, + const std::vector& axes, + bool precise = false, + StreamOrDevice s = {}); + +/** Softmax of an array. */ +array softmax(const array& a, bool precise = false, StreamOrDevice s = {}); + +/** Softmax of an array. */ +inline array +softmax(const array& a, int axis, bool precise = false, StreamOrDevice s = {}) { + return softmax(a, std::vector{axis}, precise, s); +} + +/** Raise elements of a to the power of b element-wise */ +array power(const array& a, const array& b, StreamOrDevice s = {}); + +/** Cumulative sum of an array. */ +array cumsum( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative sum of an array along the given axis. */ +array cumsum( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array. */ +array cumprod( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative product of an array along the given axis. */ +array cumprod( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array. */ +array cummax( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative max of an array along the given axis. */ +array cummax( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array. 
*/ +array cummin( + const array& a, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** Cumulative min of an array along the given axis. */ +array cummin( + const array& a, + int axis, + bool reverse = false, + bool inclusive = true, + StreamOrDevice s = {}); + +/** General convolution with a filter */ +array conv_general( + array input, + array weight, + std::vector stride = {}, + std::vector padding_lo = {}, + std::vector padding_hi = {}, + std::vector kernel_dilation = {}, + std::vector input_dilation = {}, + int groups = 1, + bool flip = false, + StreamOrDevice s = {}); + +/** General convolution with a filter */ +inline array conv_general( + const array& input, + const array& weight, + std::vector stride = {}, + std::vector padding = {}, + std::vector kernel_dilation = {}, + std::vector input_dilation = {}, + int groups = 1, + bool flip = false, + StreamOrDevice s = {}) { + return conv_general( + /* const array& input = */ input, + /* const array& weight = */ weight, + /* std::vector stride = */ stride, + /* std::vector padding_lo = */ padding, + /* std::vector padding_hi = */ padding, + /* std::vector kernel_dilation = */ kernel_dilation, + /* std::vector input_dilation = */ input_dilation, + /* int groups = */ groups, + /* bool flip = */ flip, + /* StreamOrDevice s = */ s); +} + +/** 1D convolution with a filter */ +array conv1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D convolution with a filter */ +array conv2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** 3D convolution with a filter */ +array conv3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = 
{1, 1, 1}, + int groups = 1, + StreamOrDevice s = {}); + +/** 1D transposed convolution with a filter */ +array conv_transpose1d( + const array& input, + const array& weight, + int stride = 1, + int padding = 0, + int dilation = 1, + int output_padding = 0, + int groups = 1, + StreamOrDevice s = {}); + +/** 2D transposed convolution with a filter */ +array conv_transpose2d( + const array& input, + const array& weight, + const std::pair& stride = {1, 1}, + const std::pair& padding = {0, 0}, + const std::pair& dilation = {1, 1}, + const std::pair& output_padding = {0, 0}, + int groups = 1, + StreamOrDevice s = {}); + +/** 3D transposed convolution with a filter */ +array conv_transpose3d( + const array& input, + const array& weight, + const std::tuple& stride = {1, 1, 1}, + const std::tuple& padding = {0, 0, 0}, + const std::tuple& dilation = {1, 1, 1}, + const std::tuple& output_padding = {0, 0, 0}, + int groups = 1, + StreamOrDevice s = {}); + +/** Quantized matmul multiplies x with a quantized matrix w*/ +array quantized_matmul( + array x, + array w, + array scales, + std::optional biases = std::nullopt, + bool transpose = true, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + StreamOrDevice s = {}); + +/** Quantize a matrix along its last axis */ +std::vector quantize( + const array& w, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + StreamOrDevice s = {}); + +/** Dequantize a matrix produced by quantize() */ +array dequantize( + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + std::optional dtype = std::nullopt, + StreamOrDevice s = {}); + +array qqmm( + array x, // input activations + array w, // maybe quantized weights + std::optional w_scales = std::nullopt, // 
optional scales if w is + // quantized + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "nvfp4", + StreamOrDevice s = {}); + +/** Convert an E4M3 float8 to the given floating point dtype. */ +array from_fp8(array x, Dtype dtype, StreamOrDevice s = {}); + +/** Convert a floating point matrix to E4M3 float8. */ +array to_fp8(array x, StreamOrDevice s = {}); + +/** Compute matrix products with matrix-level gather. */ +array gather_qmm( + const array& x, + const array& w, + const array& scales, + const std::optional& biases = std::nullopt, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + bool transpose = true, + std::optional group_size = std::nullopt, + std::optional bits = std::nullopt, + const std::string& mode = "affine", + bool sorted_indices = false, + StreamOrDevice s = {}); + +/** Returns a contraction of a and b over multiple dimensions. */ +array tensordot( + const array& a, + const array& b, + const int axis = 2, + StreamOrDevice s = {}); + +array tensordot( + const array& a, + const array& b, + const std::vector& axes_a, + const std::vector& axes_b, + StreamOrDevice s = {}); + +/** Compute the outer product of two vectors. */ +array outer(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute the inner product of two vectors. 
*/ +array inner(const array& a, const array& b, StreamOrDevice s = {}); + +/** Compute D = beta * C + alpha * (A @ B) */ +array addmm( + array c, + array a, + array b, + const float& alpha = 1.f, + const float& beta = 1.f, + StreamOrDevice s = {}); + +/** Compute matrix product with block masking */ +array block_masked_mm( + array a, + array b, + int block_size, + std::optional mask_out = std::nullopt, + std::optional mask_lhs = std::nullopt, + std::optional mask_rhs = std::nullopt, + StreamOrDevice s = {}); + +/** Compute matrix product with matrix-level gather */ +array gather_mm( + array a, + array b, + std::optional lhs_indices = std::nullopt, + std::optional rhs_indices = std::nullopt, + bool sorted_indices = false, + StreamOrDevice s = {}); + +/** + * Compute a matrix product but segment the inner dimension and write the + * result separately for each segment. + */ +array segmented_mm(array a, array b, array segments, StreamOrDevice s = {}); + +/** Extract a diagonal or construct a diagonal array */ +array diagonal( + const array& a, + int offset = 0, + int axis1 = 0, + int axis2 = 1, + StreamOrDevice s = {}); + +/** Extract diagonal from a 2d array or create a diagonal matrix. */ +array diag(const array& a, int k = 0, StreamOrDevice s = {}); + +/** Return the sum along a specified diagonal in the given array. */ +array trace( + const array& a, + int offset, + int axis1, + int axis2, + Dtype dtype, + StreamOrDevice s = {}); +array trace( + const array& a, + int offset, + int axis1, + int axis2, + StreamOrDevice s = {}); +array trace(const array& a, StreamOrDevice s = {}); + +/** + * Implements the identity function but allows injecting dependencies to other + * arrays. This ensures that these other arrays will have been computed + * when the outputs of this function are computed. 
+ */ +std::vector depends( + const std::vector& inputs, + const std::vector& dependencies); + +/** convert an array to an atleast ndim array */ +array atleast_1d(const array& a, StreamOrDevice s = {}); +std::vector atleast_1d( + const std::vector& a, + StreamOrDevice s = {}); +array atleast_2d(const array& a, StreamOrDevice s = {}); +std::vector atleast_2d( + const std::vector& a, + StreamOrDevice s = {}); +array atleast_3d(const array& a, StreamOrDevice s = {}); +std::vector atleast_3d( + const std::vector& a, + StreamOrDevice s = {}); + +/** + * Extract the number of elements along some axes as a scalar array. Used to + * allow shape dependent shapeless compilation (pun intended). + */ +array number_of_elements( + const array& a, + std::vector axes, + bool inverted, + Dtype dtype = int32, + StreamOrDevice s = {}); + +array conjugate(const array& a, StreamOrDevice s = {}); + +/** Bitwise and. */ +array bitwise_and(const array& a, const array& b, StreamOrDevice s = {}); +array operator&(const array& a, const array& b); + +/** Bitwise inclusive or. */ +array bitwise_or(const array& a, const array& b, StreamOrDevice s = {}); +array operator|(const array& a, const array& b); + +/** Bitwise exclusive or. */ +array bitwise_xor(const array& a, const array& b, StreamOrDevice s = {}); +array operator^(const array& a, const array& b); + +/** Shift bits to the left. */ +array left_shift(const array& a, const array& b, StreamOrDevice s = {}); +array operator<<(const array& a, const array& b); + +/** Shift bits to the right. */ +array right_shift(const array& a, const array& b, StreamOrDevice s = {}); +array operator>>(const array& a, const array& b); + +/** Invert the bits. 
*/ +array bitwise_invert(const array& a, StreamOrDevice s = {}); +array operator~(const array& a); + +array view(const array& a, const Dtype& dtype, StreamOrDevice s = {}); + +/** Roll elements along an axis and introduce them on the other side */ +array roll(const array& a, int shift, StreamOrDevice s = {}); +array roll(const array& a, const Shape& shift, StreamOrDevice s = {}); +array roll(const array& a, int shift, int axis, StreamOrDevice s = {}); +array roll( + const array& a, + int shift, + const std::vector& axes, + StreamOrDevice s = {}); +array roll(const array& a, const Shape& shift, int axis, StreamOrDevice s = {}); +array roll( + const array& a, + const Shape& shift, + const std::vector& axes, + StreamOrDevice s = {}); + +/* The real part of a complex array. */ +array real(const array& a, StreamOrDevice s = {}); + +/* The imaginary part of a complex array. */ +array imag(const array& a, StreamOrDevice s = {}); + +/* Ensure the array's underlying memory is contiguous. */ +array contiguous( + const array& a, + bool allow_col_major = false, + StreamOrDevice s = {}); + +/** @} */ + +} // namespace mlx::core diff --git a/dist/include/mlx/primitives.h b/dist/include/mlx/primitives.h new file mode 100644 index 0000000..c3ce00f --- /dev/null +++ b/dist/include/mlx/primitives.h @@ -0,0 +1,2524 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/io/load.h" +#include "mlx/stream.h" + +#define DEFINE_VMAP() \ + virtual std::pair, std::vector> vmap( \ + const std::vector& inputs, const std::vector& axes) \ + override; + +#define DEFINE_GRADS() \ + std::vector jvp( \ + const std::vector& primals, \ + const std::vector& tangents, \ + const std::vector& argnums) override; \ + \ + std::vector vjp( \ + const std::vector& primals, \ + const std::vector& cotangents, \ + const std::vector& argnums, \ + const std::vector& outputs) override; + +#define DEFINE_NAME(PRIMITIVE) \ + const char* name() const override { \ + return #PRIMITIVE; \ + } + +#define DEFINE_DEFAULT_IS_EQUIVALENT() \ + bool is_equivalent(const Primitive& other) const override { \ + return true; \ + } + +#define DEFINE_INPUT_OUTPUT_SHAPE() \ + std::vector output_shapes(const std::vector& inputs) \ + override { \ + return {inputs[0].shape()}; \ + } + +namespace mlx::core { + +// Abstract base class +class Primitive { + public: + explicit Primitive(Stream stream) : stream_(stream) {} + + /** The device the primitive will run on. */ + const Device& device() { + return stream().device; + } + + /** The stream the primitive will run on. */ + const Stream& stream() { + return stream_; + } + + /** + * A primitive must know how to evaluate itself on + * the CPU/GPU for the given inputs and populate the output arrays. + * + * To avoid unnecessary allocations, the evaluation function + * is responsible for allocating space for the array. + */ + virtual void eval_cpu( + const std::vector& inputs, + std::vector& outputs) = 0; + virtual void eval_gpu( + const std::vector& inputs, + std::vector& outputs) = 0; + + /** + * The Jacobian-vector product. + */ + virtual std::vector jvp( + const std::vector& primals, + const std::vector& tangents, + const std::vector& argnums); + + /** + * The vector-Jacobian product. 
+ */ + virtual std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs); + + /** + * The primitive must know how to vectorize itself across + * the given axes. The output is a pair containing the output arrays + * representing the vectorized computation and the axes which + * corresponds to the vectorized dimensions of each output. + */ + virtual std::pair, std::vector> vmap( + const std::vector& inputs, + const std::vector& axes); + + /** Get the name of primitive. */ + virtual const char* name() const = 0; + + /** Equivalence check defaults to false unless overridden by the primitive */ + virtual bool is_equivalent(const Primitive& other) const { + return false; + } + + /** Get the output shapes of the primitive. This is not required to be + * implemented by derived classes, in which case it will throw. */ + virtual std::vector output_shapes(const std::vector& inputs); + + virtual ~Primitive() = default; + Primitive(const Primitive& other) = delete; + Primitive(Primitive&& other) = delete; + Primitive& operator=(const Primitive& other) = delete; + Primitive& operator=(Primitive&& other) = delete; + + private: + // Every primitive stores the stream it should run in + Stream stream_; +}; + +class UnaryPrimitive : public Primitive { + /** + * An abstract base class for a primitive with a single output. 
+ */ + public: + explicit UnaryPrimitive(Stream stream) : Primitive(stream) {} + + virtual void eval_cpu(const std::vector& inputs, array& output) = 0; + virtual void eval_gpu(const std::vector& inputs, array& output) = 0; + + inline void eval_cpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_cpu(inputs, outputs[0]); + } + inline void eval_gpu( + const std::vector& inputs, + std::vector& outputs) override { + eval_gpu(inputs, outputs[0]); + } + + virtual ~UnaryPrimitive() = default; + UnaryPrimitive(const UnaryPrimitive& other) = delete; + UnaryPrimitive(UnaryPrimitive&& other) = delete; + UnaryPrimitive& operator=(const UnaryPrimitive& other) = delete; + UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; +}; + +enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; + +std::string quantization_mode_to_string(QuantizationMode mode); +QuantizationMode string_to_quantization_mode( + const std::string& mode, + std::string_view error_tag = ""); + +class Abs : public UnaryPrimitive { + public: + explicit Abs(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Abs) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Add : public UnaryPrimitive { + public: + explicit Add(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Add) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class AddMM : public UnaryPrimitive { + public: + explicit AddMM(Stream stream, float alpha, float beta) + : UnaryPrimitive(stream), alpha_(alpha), beta_(beta) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& 
out) override; + + DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_NAME(AddMM) + + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {alpha_, beta_}; + }; + + private: + const float alpha_; + const float beta_; +}; + +class Arange : public UnaryPrimitive { + public: + explicit Arange(Stream stream, double start, double stop, double step) + : UnaryPrimitive(stream), start_(start), stop_(stop), step_(step) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(Arange) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::tuple state() const { + return {start_, stop_, step_}; + }; + + private: + double start_; + double stop_; + double step_; +}; + +class ArcCos : public UnaryPrimitive { + public: + explicit ArcCos(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcCos) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcCosh : public UnaryPrimitive { + public: + explicit ArcCosh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcCosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcSin : public UnaryPrimitive { + public: + explicit ArcSin(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcSin) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcSinh : public 
UnaryPrimitive { + public: + explicit ArcSinh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcSinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTan : public UnaryPrimitive { + public: + explicit ArcTan(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTan) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTan2 : public UnaryPrimitive { + public: + explicit ArcTan2(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTan2) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArcTanh : public UnaryPrimitive { + public: + explicit ArcTanh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArcTanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ArgPartition : public UnaryPrimitive { + public: + explicit ArgPartition(Stream stream, int kth, int axis) + : UnaryPrimitive(stream), kth_(kth), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgPartition) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {kth_, axis_}; + }; + + private: + int kth_; + 
int axis_; +}; + +class ArgReduce : public UnaryPrimitive { + public: + enum ReduceType { + ArgMin, + ArgMax, + }; + + explicit ArgReduce(Stream stream, ReduceType reduce_type, int axis) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgReduce) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair state() const { + return {reduce_type_, axis_}; + }; + + private: + ReduceType reduce_type_; + int axis_; +}; + +class ArgSort : public UnaryPrimitive { + public: + explicit ArgSort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ArgSort) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + int state() const { + return axis_; + }; + + private: + int axis_; +}; + +class AsType : public UnaryPrimitive { + public: + explicit AsType(Stream stream, Dtype dtype) + : UnaryPrimitive(stream), dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(AsType) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + Dtype state() const { + return dtype_; + }; + + private: + Dtype dtype_; +}; + +class AsStrided : public UnaryPrimitive { + public: + explicit AsStrided(Stream stream, Shape shape, Strides strides, size_t offset) + : UnaryPrimitive(stream), + shape_(std::move(shape)), + strides_(std::move(strides)), + offset_(offset) {} + + void eval_cpu(const std::vector& inputs, array& out) 
override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_NAME(AsStrided) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(shape_, strides_, offset_); + } + + private: + Shape shape_; + Strides strides_; + size_t offset_; + + void eval(const std::vector& inputs, array& out); +}; + +class BitwiseBinary : public UnaryPrimitive { + public: + enum Op { And, Or, Xor, LeftShift, RightShift }; + + explicit BitwiseBinary(Stream stream, Op op) + : UnaryPrimitive(stream), op_(op) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + + const char* name() const override { + switch (op_) { + case BitwiseBinary::And: + return "BitwiseAnd"; + case BitwiseBinary::Or: + return "BitwiseOr"; + case BitwiseBinary::Xor: + return "BitwiseXor"; + case BitwiseBinary::LeftShift: + return "LeftShift"; + case BitwiseBinary::RightShift: + return "RightShift"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return op_; + } + + private: + Op op_; +}; + +class BitwiseInvert : public UnaryPrimitive { + public: + explicit BitwiseInvert(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(BitwiseInvert) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class BlockMaskedMM : public UnaryPrimitive { + public: + explicit BlockMaskedMM(Stream stream, int block_size) + : UnaryPrimitive(stream), block_size_(block_size) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& 
cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(BlockMaskedMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return block_size_; + } + + private: + int block_size_; +}; + +class GatherMM : public UnaryPrimitive { + public: + explicit GatherMM( + Stream stream, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(GatherMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(left_sorted_, right_sorted_); + } + + private: + bool left_sorted_; + bool right_sorted_; +}; + +class SegmentedMM : public UnaryPrimitive { + public: + explicit SegmentedMM(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(SegmentedMM) +}; + +class BroadcastAxes : public UnaryPrimitive { + public: + explicit BroadcastAxes(Stream stream, std::vector ignore_axes = {}) + : UnaryPrimitive(stream), ignore_axes_(std::move(ignore_axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(BroadcastAxes) + bool is_equivalent(const Primitive& other) const override; + static Shape output_shape( + const std::vector& inputs, + const std::vector& ignore_axes); + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return ignore_axes_; + } + + private: + void eval(const 
std::vector& inputs, array& out); + std::vector ignore_axes_; +}; + +class Broadcast : public UnaryPrimitive { + public: + explicit Broadcast(Stream stream, const Shape& shape) + : UnaryPrimitive(stream), shape_(shape) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Broadcast) + static Shape output_shape(const std::vector& inputs); + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + Shape state() const { + return shape_; + }; + + private: + Shape shape_; + + void eval(const std::vector& inputs, array& out); +}; + +class Ceil : public UnaryPrimitive { + public: + explicit Ceil(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Ceil) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Compiled : public Primitive { + public: + /* + * The inputs, outputs and tape are either tracers or constants. + * - The tape should not contain the inputs, but it should contain the + * outputs. + * - The tape should also have only one array per primitive for multi-output + * primitives. + * - The constant_ids contains ids of arrays in the input list that are safe + * to treat as scalar constants. 
+ */ + explicit Compiled( + Stream stream, + std::vector inputs, + std::vector outputs, + std::vector tape, + std::unordered_set constant_ids); + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + const char* name() const override; + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + std::string lib_name() const { + return kernel_lib_; + } + + private: + const std::vector inputs_; + const std::vector outputs_; + const std::vector tape_; + const std::unordered_set constant_ids_; + const std::function is_constant_; + + mutable std::string name_; + std::string kernel_lib_; +}; + +class Concatenate : public UnaryPrimitive { + public: + explicit Concatenate(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Concatenate) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Conjugate : public UnaryPrimitive { + public: + explicit Conjugate(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(Conjugate) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Contiguous : public UnaryPrimitive { + public: + explicit Contiguous(Stream stream, bool allow_col_major) + : UnaryPrimitive(stream), allow_col_major_(allow_col_major) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& 
out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Contiguous) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + + private: + bool allow_col_major_; +}; + +class Convolution : public UnaryPrimitive { + public: + explicit Convolution( + Stream stream, + const std::vector& kernel_strides, + const std::vector& padding_lo, + const std::vector& padding_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation, + const int groups = 1, + const bool flip = false) + : UnaryPrimitive(stream), + padding_lo_(padding_lo), + padding_hi_(padding_hi), + kernel_strides_(kernel_strides), + kernel_dilation_(kernel_dilation), + input_dilation_(input_dilation), + groups_(groups), + flip_(flip) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotangents, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_VMAP() + DEFINE_NAME(Convolution) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple( + kernel_strides_, + padding_lo_, + padding_hi_, + kernel_dilation_, + input_dilation_, + groups_, + flip_); + } + + static Shape conv_out_shape( + const Shape& in_shape, + const Shape& wt_shape, + const std::vector& strides, + const std::vector& pads_lo, + const std::vector& pads_hi, + const std::vector& kernel_dilation, + const std::vector& input_dilation); + + private: + std::vector padding_lo_; + std::vector padding_hi_; + std::vector kernel_strides_; + std::vector kernel_dilation_; + std::vector input_dilation_; + int groups_; + bool flip_; +}; + +class Copy : public UnaryPrimitive { + public: + explicit Copy(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) 
override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Copy) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Cos : public UnaryPrimitive { + public: + explicit Cos(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Cos) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Cosh : public UnaryPrimitive { + public: + explicit Cosh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Cosh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class CustomTransforms : public Primitive { + public: + explicit CustomTransforms( + Stream stream, + int num_outputs, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> vjp, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> jvp, + std::function, std::vector>( + const std::vector&, + const std::vector&)> vmap) + : Primitive(stream), + num_outputs_(num_outputs), + vjp_fun_(std::move(vjp)), + jvp_fun_(std::move(jvp)), + vmap_fun_(std::move(vmap)) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_GRADS(); + DEFINE_VMAP(); + DEFINE_NAME(CustomTransforms); + + private: + void eval(const std::vector& inputs, std::vector& outputs); + + int num_outputs_; + + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> + vjp_fun_; + std::function( + const std::vector&, + const std::vector&, + const 
std::vector&)> + jvp_fun_; + std::function, std::vector>( + const std::vector&, + const std::vector&)> + vmap_fun_; +}; + +class Depends : public Primitive { + public: + explicit Depends(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + std::vector vjp( + const std::vector& primals, + const std::vector& cotan, + const std::vector& argnums, + const std::vector& outputs) override; + + DEFINE_NAME(Depends); + + private: + void eval(const std::vector& inputs, std::vector& outputs); +}; + +class Divide : public UnaryPrimitive { + public: + explicit Divide(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Divide) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class DivMod : public Primitive { + public: + explicit DivMod(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DivMod) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override { + return std::vector{inputs[0].shape(), inputs[0].shape()}; + } +}; + +class Select : public UnaryPrimitive { + public: + explicit Select(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Select) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Remainder : public UnaryPrimitive { + public: + explicit Remainder(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const 
std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Remainder) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Equal : public UnaryPrimitive { + public: + explicit Equal(Stream stream, bool equal_nan = false) + : UnaryPrimitive(stream), equal_nan_(equal_nan) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + const char* name() const override { + if (equal_nan_) { + return "NaNEqual"; + } else { + return "Equal"; + } + } + auto state() const { + return equal_nan_; + }; + + private: + bool equal_nan_; +}; + +class Erf : public UnaryPrimitive { + public: + explicit Erf(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Erf) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ErfInv : public UnaryPrimitive { + public: + explicit ErfInv(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ErfInv) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Exp : public UnaryPrimitive { + public: + explicit Exp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Exp) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Expm1 : public UnaryPrimitive { + public: + explicit Expm1(Stream stream) 
: UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Expm1) + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class ExpandDims : public UnaryPrimitive { + public: + explicit ExpandDims(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(ExpandDims) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + auto state() const { + return axes_; + } + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + +class FFT : public UnaryPrimitive { + public: + explicit FFT( + Stream stream, + const std::vector& axes, + bool inverse, + bool real) + : UnaryPrimitive(stream), axes_(axes), inverse_(inverse), real_(real) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(FFT) + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(axes_, inverse_, real_); + } + + private: + std::vector axes_; + bool inverse_; + bool real_; +}; + +class Flatten : public UnaryPrimitive { + public: + explicit Flatten(Stream stream, int start_axis, int end_axis) + : UnaryPrimitive(stream), start_axis_(start_axis), end_axis_(end_axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Flatten) + std::vector output_shapes(const std::vector& inputs) 
override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, int start_axis, int end_axis); + auto state() const { + return std::make_pair(start_axis_, end_axis_); + } + + private: + int start_axis_; + int end_axis_; + void eval(const std::vector& inputs, array& out); +}; + +class Floor : public UnaryPrimitive { + public: + explicit Floor(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Floor) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Full : public UnaryPrimitive { + public: + explicit Full(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Full) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Gather : public UnaryPrimitive { + public: + explicit Gather(Stream stream, std::vector axes, Shape slice_sizes) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_sizes_(std::move(slice_sizes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Gather) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair, Shape> state() const { + return {axes_, slice_sizes_}; + } + + private: + std::vector axes_; + Shape slice_sizes_; +}; + +class GatherAxis : public UnaryPrimitive { + public: + explicit GatherAxis(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) 
override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GatherAxis) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Greater : public UnaryPrimitive { + public: + explicit Greater(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Greater) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class GreaterEqual : public UnaryPrimitive { + public: + explicit GreaterEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GreaterEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Hadamard : public UnaryPrimitive { + public: + explicit Hadamard(Stream stream, float scale) + : UnaryPrimitive(stream), scale_(scale) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Hadamard) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return scale_; + } + + private: + float scale_; +}; + +class Imag : public UnaryPrimitive { + public: + explicit Imag(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Imag) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Less : public UnaryPrimitive { + public: + explicit Less(Stream stream) : 
UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Less) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LessEqual : public UnaryPrimitive { + public: + explicit LessEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LessEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Load : public UnaryPrimitive { + public: + explicit Load( + Stream stream, + std::shared_ptr reader, + size_t offset, + bool swap_endianness = false) + : UnaryPrimitive(stream), + reader_(std::move(reader)), + offset_(offset), + swap_endianness_(swap_endianness) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_NAME(Load) + + private: + std::shared_ptr reader_; + size_t offset_; + bool swap_endianness_; +}; + +class Log : public UnaryPrimitive { + public: + enum Base { two, ten, e }; + + explicit Log(Stream stream, Base base) + : UnaryPrimitive(stream), base_(base) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + Base state() const { + return base_; + }; + + const char* name() const override { + switch (base_) { + case e: + return "Log"; + case two: + return "Log2"; + case ten: + return "Log10"; + } + return ""; + } + + private: + Base base_; +}; + +class Log1p : public UnaryPrimitive { + public: + explicit Log1p(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void 
eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Log1p) + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalNot : public UnaryPrimitive { + public: + explicit LogicalNot(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalNot) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalAnd : public UnaryPrimitive { + public: + explicit LogicalAnd(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalAnd) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogicalOr : public UnaryPrimitive { + public: + explicit LogicalOr(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogicalOr) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogAddExp : public UnaryPrimitive { + public: + explicit LogAddExp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogAddExp) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class LogSumExp : public UnaryPrimitive { + public: + explicit LogSumExp(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(LogSumExp) + 
DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; +}; + +class Matmul : public UnaryPrimitive { + public: + explicit Matmul(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_GRADS() + DEFINE_VMAP() + DEFINE_NAME(Matmul) + DEFINE_DEFAULT_IS_EQUIVALENT() + std::vector output_shapes(const std::vector& inputs) override; +}; + +class Maximum : public UnaryPrimitive { + public: + explicit Maximum(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Maximum) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Minimum : public UnaryPrimitive { + public: + explicit Minimum(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Minimum) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Multiply : public UnaryPrimitive { + public: + explicit Multiply(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Multiply) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Negative : public UnaryPrimitive { + public: + explicit Negative(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Negative) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + 
+class NotEqual : public UnaryPrimitive { + public: + explicit NotEqual(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(NotEqual) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class NumberOfElements : public UnaryPrimitive { + public: + explicit NumberOfElements( + Stream stream, + std::vector axes, + bool inverted, + Dtype dtype) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + inverted_(inverted), + dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(NumberOfElements) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override { + return {{}}; + } + std::tuple, bool, Dtype> state() const { + return {axes_, inverted_, dtype_}; + } + + private: + std::vector axes_; + bool inverted_; + Dtype dtype_; + + void eval(const std::vector& inputs, array& out); +}; + +class Pad : public UnaryPrimitive { + public: + explicit Pad( + Stream stream, + const std::vector& axes, + const Shape& low_pad_size, + const Shape& high_pad_size) + : UnaryPrimitive(stream), + axes_(axes), + low_pad_size_(low_pad_size), + high_pad_size_(high_pad_size) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Pad) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(axes_, low_pad_size_, high_pad_size_); + } + + private: + std::vector axes_; + Shape low_pad_size_; + Shape high_pad_size_; +}; + +class Partition : public UnaryPrimitive { + public: + explicit Partition(Stream stream, int kth, int axis) + : 
UnaryPrimitive(stream), kth_(kth), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Partition) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(kth_, axis_); + }; + + private: + int kth_; + int axis_; +}; + +class Power : public UnaryPrimitive { + public: + explicit Power(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Power) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class QuantizedMatmul : public UnaryPrimitive { + public: + explicit QuantizedMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool transpose) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + transpose_(transpose) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(QuantizedMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_, transpose_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; +}; + +class QQMatmul : public UnaryPrimitive { + public: + explicit QQMatmul( + Stream stream, + int group_size, + int bits, + QuantizationMode mode) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + // DEFINE_VMAP() + 
DEFINE_GRADS() + DEFINE_NAME(QQMatmul) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_tuple(group_size_, bits_, mode_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; +}; + +class GatherQMM : public UnaryPrimitive { + public: + explicit GatherQMM( + Stream stream, + int group_size, + int bits, + QuantizationMode mode, + bool transpose, + bool left_sorted = false, + bool right_sorted = false) + : UnaryPrimitive(stream), + group_size_(group_size), + bits_(bits), + mode_(mode), + transpose_(transpose), + left_sorted_(left_sorted), + right_sorted_(right_sorted) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(GatherQMM) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple( + group_size_, bits_, mode_, transpose_, left_sorted_, right_sorted_); + } + + private: + int group_size_; + int bits_; + QuantizationMode mode_; + bool transpose_; + bool left_sorted_; + bool right_sorted_; +}; + +class RandomBits : public UnaryPrimitive { + public: + explicit RandomBits(Stream stream, const Shape& shape, int width) + : UnaryPrimitive(stream), shape_(shape), width_(width) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(RandomBits) + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {shape_, width_}; + }; + + private: + Shape shape_; + int width_; +}; + +class Real : public UnaryPrimitive { + public: + explicit Real(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; 
+ + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Real) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Reshape : public UnaryPrimitive { + public: + explicit Reshape(Stream stream, const Shape& shape) + : UnaryPrimitive(stream), shape_(shape) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Reshape) + bool is_equivalent(const Primitive& other) const override; + Shape state() const { + return shape_; + }; + static Shape output_shape(const array& input, Shape shape); + std::vector output_shapes(const std::vector& inputs) override; + + private: + Shape shape_; +}; + +class Reduce : public UnaryPrimitive { + public: + enum ReduceType { And, Or, Sum, Prod, Min, Max }; + + explicit Reduce( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS(); + + std::vector output_shapes(const std::vector& inputs) override; + + const char* name() const override { + switch (reduce_type_) { + case And: + return "And"; + case Or: + return "Or"; + case Sum: + return "Sum"; + case Prod: + return "Prod"; + case Min: + return "Min"; + case Max: + return "Max"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::pair> state() const { + return {reduce_type_, axes_}; + }; + + private: + ReduceType reduce_type_; + std::vector axes_; +}; + +class Round : public UnaryPrimitive { + public: + explicit Round(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Round) + 
DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Scan : public UnaryPrimitive { + public: + enum ReduceType { Max, Min, Sum, Prod, LogAddExp }; + + explicit Scan( + Stream stream, + ReduceType reduce_type, + int axis, + bool reverse, + bool inclusive) + : UnaryPrimitive(stream), + reduce_type_(reduce_type), + axis_(axis), + reverse_(reverse), + inclusive_(inclusive) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS(); + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "CumSum"; + case Prod: + return "CumProd"; + case Min: + return "CumMin"; + case Max: + return "CumMax"; + case LogAddExp: + return "CumLogAddExp"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(reduce_type_, axis_, reverse_, inclusive_); + } + + private: + ReduceType reduce_type_; + int axis_; + bool reverse_; + bool inclusive_; +}; + +class Scatter : public UnaryPrimitive { + public: + enum ReduceType { Max, Min, Sum, Prod, None }; + + explicit Scatter( + Stream stream, + ReduceType reduce_type, + const std::vector& axes) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP(); + DEFINE_GRADS(); + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "Scatter Sum"; + case Prod: + return "Scatter Prod"; + case Min: + return "Scatter Min"; + case Max: + return "Scatter Max"; + case None: + return "Scatter"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::pair> state() const { + return {reduce_type_, axes_}; + }; + + private: + ReduceType reduce_type_; + std::vector axes_; +}; + +class ScatterAxis : public 
UnaryPrimitive { + public: + enum ReduceType { Sum, None }; + + explicit ScatterAxis(Stream stream, ReduceType reduce_type, int axis) + : UnaryPrimitive(stream), reduce_type_(reduce_type), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + + const char* name() const override { + switch (reduce_type_) { + case Sum: + return "ScatterAxis Sum"; + case None: + return "ScatterAxis"; + } + return ""; + } + + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::pair state() const { + return {reduce_type_, axis_}; + } + + private: + ReduceType reduce_type_; + int axis_; +}; + +class MaskedScatter : public UnaryPrimitive { + public: + explicit MaskedScatter(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP(); + DEFINE_GRADS(); + DEFINE_NAME(MaskedScatter); + DEFINE_DEFAULT_IS_EQUIVALENT(); + DEFINE_INPUT_OUTPUT_SHAPE(); +}; + +class Sigmoid : public UnaryPrimitive { + public: + explicit Sigmoid(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sigmoid) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sign : public UnaryPrimitive { + public: + explicit Sign(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sign) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sin : public UnaryPrimitive { + public: + explicit Sin(Stream stream) : 
UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sin) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sinh : public UnaryPrimitive { + public: + explicit Sinh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sinh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Slice : public UnaryPrimitive { + public: + explicit Slice( + Stream stream, + const Shape& start_indices, + const Shape& end_indices, + const Shape& strides) + : UnaryPrimitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Slice) + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_tuple(start_indices_, end_indices_, strides_); + } + + private: + Shape start_indices_; + Shape end_indices_; + Shape strides_; +}; + +class SliceUpdate : public UnaryPrimitive { + public: + explicit SliceUpdate( + Stream stream, + const Shape& start_indices, + const Shape& end_indices, + const Shape& strides) + : UnaryPrimitive(stream), + start_indices_(start_indices), + end_indices_(end_indices), + strides_(strides) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(SliceUpdate) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return std::make_tuple(start_indices_, 
end_indices_, strides_); + } + + private: + Shape start_indices_; + Shape end_indices_; + Shape strides_; +}; + +class DynamicSlice : public UnaryPrimitive { + public: + explicit DynamicSlice(Stream stream, std::vector axes, Shape slice_size) + : UnaryPrimitive(stream), + axes_(std::move(axes)), + slice_size_(std::move(slice_size)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DynamicSlice) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + auto state() const { + return std::make_pair(axes_, slice_size_); + } + + private: + std::vector axes_; + Shape slice_size_; +}; + +class DynamicSliceUpdate : public UnaryPrimitive { + public: + explicit DynamicSliceUpdate(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(DynamicSliceUpdate) + bool is_equivalent(const Primitive& other) const override; + DEFINE_INPUT_OUTPUT_SHAPE() + auto state() const { + return axes_; + } + + private: + std::vector axes_; +}; + +class Softmax : public UnaryPrimitive { + public: + explicit Softmax(Stream stream, bool precise) + : UnaryPrimitive(stream), precise_(precise) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Softmax) + DEFINE_INPUT_OUTPUT_SHAPE() + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return precise_; + }; + + private: + bool precise_; +}; + +class Sort : public UnaryPrimitive { + public: + explicit Sort(Stream stream, int axis) + : UnaryPrimitive(stream), axis_(axis) {} + + void 
eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Sort) + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return axis_; + } + + private: + int axis_; +}; + +class Split : public Primitive { + public: + explicit Split(Stream stream, const Shape& indices, int axis) + : Primitive(stream), indices_(indices), axis_(axis) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Split) + bool is_equivalent(const Primitive& other) const override; + std::pair state() const { + return {indices_, axis_}; + }; + + private: + void eval(const std::vector& inputs, std::vector& outputs); + + Shape indices_; + int axis_; +}; + +class Square : public UnaryPrimitive { + public: + explicit Square(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Square) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Sqrt : public UnaryPrimitive { + public: + explicit Sqrt(Stream stream, bool recip = false) + : UnaryPrimitive(stream), recip_(recip) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_INPUT_OUTPUT_SHAPE() + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return recip_; + } + + const char* name() const override { + if (recip_) { + return "Rsqrt"; + } else { + return "Sqrt"; + } + } + + private: + bool recip_; +}; + +class StopGradient : public UnaryPrimitive { + public: + explicit 
StopGradient(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_NAME(StopGradient) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + void eval(const std::vector& inputs, array& out); +}; + +class Subtract : public UnaryPrimitive { + public: + explicit Subtract(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Subtract) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Squeeze : public UnaryPrimitive { + public: + explicit Squeeze(Stream stream, std::vector axes) + : UnaryPrimitive(stream), axes_(std::move(axes)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Squeeze) + + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, const std::vector& axes); + auto state() const { + return axes_; + }; + + private: + void eval(const std::vector& inputs, array& out); + std::vector axes_; +}; + +class Tan : public UnaryPrimitive { + public: + explicit Tan(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Tan) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Tanh : public UnaryPrimitive { + public: + explicit Tanh(Stream stream) : UnaryPrimitive(stream) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& 
inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Tanh) + DEFINE_DEFAULT_IS_EQUIVALENT() + DEFINE_INPUT_OUTPUT_SHAPE() +}; + +class Unflatten : public UnaryPrimitive { + public: + explicit Unflatten(Stream stream, int axis, Shape shape) + : UnaryPrimitive(stream), axis_(axis), shape_(std::move(shape)) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Unflatten) + std::vector output_shapes(const std::vector& inputs) override; + bool is_equivalent(const Primitive& other) const override; + + static Shape output_shape(const array& input, int axis, const Shape& shape); + auto state() const { + return std::make_pair(axis_, shape_); + } + + private: + int axis_; + Shape shape_; + void eval(const std::vector& inputs, array& out); +}; + +class View : public UnaryPrimitive { + public: + explicit View(Stream stream, Dtype dtype) + : UnaryPrimitive(stream), dtype_(dtype) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + const char* name() const override; + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return dtype_; + } + + private: + Dtype dtype_; + mutable std::string name_; +}; + +class Transpose : public UnaryPrimitive { + public: + explicit Transpose(Stream stream, const std::vector& axes) + : UnaryPrimitive(stream), axes_(axes) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_GRADS() + DEFINE_NAME(Transpose) + bool is_equivalent(const Primitive& other) const override; + std::vector output_shapes(const std::vector& inputs) override; + std::vector state() const { + return axes_; + }; + + private: + std::vector axes_; + + void eval(const std::vector& inputs, 
array& out); +}; + +/* QR Factorization primitive. */ +class QRF : public Primitive { + public: + explicit QRF(Stream stream) : Primitive(stream) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(QRF) +}; + +/* SVD primitive. */ +class SVD : public Primitive { + public: + explicit SVD(Stream stream, bool compute_uv) + : Primitive(stream), compute_uv_(compute_uv) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(SVD) + auto state() const { + return compute_uv_; + } + + private: + bool compute_uv_; +}; + +/* Matrix inversion primitive. */ +class Inverse : public UnaryPrimitive { + public: + explicit Inverse(Stream stream, bool tri, bool upper) + : UnaryPrimitive(stream), tri_(tri), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& output) override; + void eval_gpu(const std::vector& inputs, array& output) override; + + DEFINE_VMAP() + DEFINE_NAME(Inverse) + auto state() const { + return std::make_pair(tri_, upper_); + } + + private: + bool tri_; + bool upper_; +}; + +class Cholesky : public UnaryPrimitive { + public: + explicit Cholesky(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + auto state() const { + return upper_; + } + + DEFINE_VMAP() + DEFINE_NAME(Cholesky) + + private: + bool upper_; +}; + +class Eig : public Primitive { + public: + explicit Eig(Stream stream, bool compute_eigenvectors) + : Primitive(stream), compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + 
DEFINE_VMAP() + DEFINE_NAME(Eig) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return compute_eigenvectors_; + } + + private: + bool compute_eigenvectors_; +}; + +class Eigh : public Primitive { + public: + explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors) + : Primitive(stream), + uplo_(std::move(uplo)), + compute_eigenvectors_(compute_eigenvectors) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_VMAP() + DEFINE_NAME(Eigh) + + std::vector output_shapes(const std::vector& inputs) override; + + bool is_equivalent(const Primitive& other) const override; + auto state() const { + return std::make_pair(uplo_, compute_eigenvectors_); + } + + private: + std::string uplo_; + bool compute_eigenvectors_; +}; + +/* LU Factorization primitive. */ +class LUF : public Primitive { + public: + explicit LUF(Stream stream) : Primitive(stream) {} + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + DEFINE_NAME(LUF) +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/random.h b/dist/include/mlx/random.h new file mode 100644 index 0000000..0dfdab7 --- /dev/null +++ b/dist/include/mlx/random.h @@ -0,0 +1,282 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/stream.h" +#include "mlx/utils.h" + +namespace mlx::core::random { + +class KeySequence { + public: + explicit KeySequence(uint64_t seed); + + void seed(uint64_t seed); + array next(); + + // static default + static KeySequence& default_() { + static KeySequence ks(get_current_time_seed()); + return ks; + } + + private: + array key_; + static uint64_t get_current_time_seed() { + auto now = std::chrono::system_clock::now(); + return std::chrono::duration_cast( + now.time_since_epoch()) + .count(); + } +}; + +/** Get a PRNG key from a seed. */ +array key(uint64_t seed); + +/** Seed the default PRNG key. */ +void seed(uint64_t seed); + +/** Generate an array with type uint32 filled with random bits. */ +array bits( + const Shape& shape, + int width, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array bits( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bits(shape, 4, key, s); +} + +/** Split the rng key into a pair of keys. */ +std::pair split(const array& key, StreamOrDevice s = {}); + +/** Split the rng key into `num` keys. */ +array split(const array& key, int num, StreamOrDevice s = {}); + +/** Generate uniform random numbers between low and high. */ +array uniform( + const array& low, + const array& high, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array uniform( + T low, + U high, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate uniform random numbers between 0 and 1. 
*/ +array uniform( + const Shape& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array uniform( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return uniform(shape, float32, key); +} + +/** Generate samples from the standard normal distribution. */ +array normal( + const Shape& shape, + Dtype dtype, + const std::optional& loc, + const std::optional& scale, + const std::optional& key, + StreamOrDevice s = {}); +inline array normal( + const Shape& shape, + Dtype dtype, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + auto loc_ = loc == 0 ? std::nullopt : std::make_optional(array(loc, dtype)); + auto scale_ = + scale == 1 ? std::nullopt : std::make_optional(array(scale, dtype)); + return normal(shape, dtype, loc_, scale_, key, s); +} +inline array normal( + const Shape& shape, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, loc, scale, key, s); +} +inline array normal( + const Shape& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, dtype, std::nullopt, std::nullopt, key, s); +} +inline array normal( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return normal(shape, float32, std::nullopt, std::nullopt, key, s); +} + +/** Generate samples from a multivariate normal distribution. 
**/ +array multivariate_normal( + const array& mean, + const array& cov, + const Shape& shape, + Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/** Generate integer samples uniformly at random */ +array randint( + const array& low, + const array& high, + const Shape& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array randint( + T low, + U high, + const Shape& shape, + Dtype dtype = int32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return randint(array(low), array(high), shape, dtype, key, to_stream(s)); +} + +/** Generate binary variables with probability to be true equal to p */ +array bernoulli( + const array& p, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +array bernoulli( + const array& p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +template +array bernoulli( + T p, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), key, s); +} + +template +array bernoulli( + T p, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return bernoulli(array(p), shape, key, s); +} + +array bernoulli( + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array truncated_normal( + const array& lower, + const array& upper, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array gumbel( + const Shape& shape, + Dtype dtype = float32, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis, + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); 
+ +array categorical( + const array& logits_, + int axis, + int num_samples, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +array categorical( + const array& logits, + int axis = -1, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/** Generate samples from the laplace distribution. */ +array laplace( + const Shape& shape, + Dtype dtype, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); +inline array laplace( + const Shape& shape, + const float loc, + const float scale, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, loc, scale, key, s); +} +inline array laplace( + const Shape& shape, + const Dtype dtype, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, dtype, 0.0, 1.0, key, s); +} +inline array laplace( + const Shape& shape, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}) { + return laplace(shape, float32, 0.0, 1.0, key, s); +} + +/* Randomly permute the elements of x along the given axis. */ +array permutation( + const array& x, + int axis = 0, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +/* A random permutation of `arange(x)` */ +array permutation( + int x, + const std::optional& key = std::nullopt, + StreamOrDevice s = {}); + +} // namespace mlx::core::random diff --git a/dist/include/mlx/scheduler.h b/dist/include/mlx/scheduler.h new file mode 100644 index 0000000..d01d414 --- /dev/null +++ b/dist/include/mlx/scheduler.h @@ -0,0 +1,188 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "mlx/backend/gpu/eval.h" +#include "mlx/device.h" +#include "mlx/stream.h" + +namespace mlx::core::scheduler { + +struct StreamThread { + std::mutex mtx; + std::queue> q; + std::condition_variable cond; + bool stop; + std::thread thread; + + StreamThread() : stop(false), thread(&StreamThread::thread_fn, this) {} + + ~StreamThread() { + { + std::lock_guard lk(mtx); + stop = true; + } + cond.notify_one(); + thread.join(); + } + + void thread_fn() { + while (true) { + std::function task; + { + std::unique_lock lk(mtx); + cond.wait(lk, [this] { return !this->q.empty() || this->stop; }); + if (q.empty() && stop) { + return; + } + task = std::move(q.front()); + q.pop(); + } + + task(); + } + } + + template + void enqueue(F&& f) { + { + std::lock_guard lk(mtx); + if (stop) { + throw std::runtime_error( + "Cannot enqueue work after stream is stopped."); + } + q.emplace(std::forward(f)); + } + cond.notify_one(); + } +}; + +class Scheduler { + public: + Scheduler() : n_active_tasks_(0) { + if (is_available(Device::gpu)) { + default_streams_.insert({Device::gpu, new_stream(Device::gpu)}); + } + default_streams_.insert({Device::cpu, new_stream(Device::cpu)}); + } + + // Not copyable or moveable + Scheduler(const Scheduler&) = delete; + Scheduler(Scheduler&&) = delete; + Scheduler& operator=(const Scheduler&) = delete; + Scheduler& operator=(Scheduler&&) = delete; + + Stream new_stream(const Device& d) { + streams_.emplace_back(streams_.size(), d); + if (d == Device::gpu) { + threads_.push_back(nullptr); + gpu::new_stream(streams_.back()); + } else { + threads_.push_back(new StreamThread{}); + } + return streams_.back(); + } + + template + void enqueue(const Stream& stream, F&& f); + + Stream get_default_stream(const Device& d) const { + return default_streams_.at(d.type); + } + Stream get_stream(int index) const { + return streams_.at(index); + } + + void set_default_stream(const Stream& s) { + 
default_streams_.at(s.device.type) = s; + } + + void notify_new_task(const Stream& stream) { + { + std::lock_guard lk(mtx); + n_active_tasks_++; + } + completion_cv.notify_all(); + } + + void notify_task_completion(const Stream& stream) { + { + std::lock_guard lk(mtx); + n_active_tasks_--; + } + completion_cv.notify_all(); + } + + int n_active_tasks() const { + return n_active_tasks_; + } + + void wait_for_one() { + std::unique_lock lk(mtx); + int n_tasks_old = n_active_tasks(); + if (n_tasks_old > 1) { + completion_cv.wait(lk, [this, n_tasks_old] { + return this->n_active_tasks() < n_tasks_old; + }); + } + } + + ~Scheduler() { + for (auto s : streams_) { + try { + synchronize(s); + } catch (const std::runtime_error&) { + // ignore errors if synch fails + } + } + for (auto t : threads_) { + if (t != nullptr) { + delete t; + } + } + } + + private: + int n_active_tasks_; + std::vector threads_; + std::vector streams_; + std::unordered_map default_streams_; + std::condition_variable completion_cv; + std::mutex mtx; +}; + +template +void Scheduler::enqueue(const Stream& stream, F&& f) { + threads_[stream.index]->enqueue(std::forward(f)); +} + +Scheduler& scheduler(); + +template +void enqueue(const Stream& stream, F&& f) { + scheduler().enqueue(stream, std::forward(f)); +} + +inline int n_active_tasks() { + return scheduler().n_active_tasks(); +} + +inline void notify_new_task(const Stream& stream) { + scheduler().notify_new_task(stream); +} + +inline void notify_task_completion(const Stream& stream) { + scheduler().notify_task_completion(stream); +} + +inline void wait_for_one() { + scheduler().wait_for_one(); +} + +} // namespace mlx::core::scheduler diff --git a/dist/include/mlx/small_vector.h b/dist/include/mlx/small_vector.h new file mode 100644 index 0000000..143101c --- /dev/null +++ b/dist/include/mlx/small_vector.h @@ -0,0 +1,540 @@ +// Copyright © 2025 Apple Inc. +// Copyright © 2018 the V8 project authors. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following +// disclaimer in the documentation and/or other materials provided +// with the distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +#pragma once + +#include +#include +#include +#include + +namespace mlx::core { + +#if defined(__has_builtin) +#define MLX_HAS_BUILTIN(x) __has_builtin(x) +#else +#define MLX_HAS_BUILTIN(x) 0 +#endif + +#if defined(__has_attribute) +#define MLX_HAS_ATTRIBUTE(x) __has_attribute(x) +#else +#define MLX_HAS_ATTRIBUTE(x) 0 +#endif + +#if MLX_HAS_BUILTIN(__builtin_expect) +#define MLX_LIKELY(condition) (__builtin_expect(!!(condition), 1)) +#define MLX_UNLIKELY(condition) (__builtin_expect(!!(condition), 0)) +#else +#define MLX_LIKELY(condition) (condition) +#define MLX_UNLIKELY(condition) (condition) +#endif + +#if MLX_HAS_ATTRIBUTE(noinline) +#define MLX_NOINLINE __attribute__((noinline)) +#else +#define MLX_NOINLINE +#endif + +template +struct is_iterator : std::false_type {}; + +template +struct is_iterator< + T, + std::void_t< + typename std::iterator_traits::difference_type, + typename std::iterator_traits::iterator_category, + typename std::iterator_traits::pointer, + typename std::iterator_traits::reference, + typename std::iterator_traits::value_type>> : std::true_type {}; + +template +constexpr bool is_iterator_v = is_iterator::value; + +// Minimal SmallVector implementation. Uses inline storage first, switches to +// dynamic storage when it overflows. +// +// Notes: +// * The default inline storage size is MAX_NDIM, as it is mainly used for +// shapes and strides, users should choose a better size for other cases. +// * The data() returns real address even for empty vector. +// * The pointer returned by data() will change after moving the vector as it +// points to the inline storage. +// * For trivial elements the storage will not be default constructed, +// i.e. SmallVector(10) will not be filled with 0 by default. 
+template > +class SmallVector { + public: + using value_type = T; + using reference = T&; + using const_reference = const T&; + using iterator = T*; + using const_iterator = const T*; + using difference_type = std::ptrdiff_t; + using size_type = std::size_t; + + SmallVector() = default; + + explicit SmallVector(const Allocator& allocator) : allocator_(allocator) {} + + explicit SmallVector(size_t size, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size); + } + + SmallVector( + size_t size, + const T& initial_value, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + resize(size, initial_value); + } + + SmallVector( + std::initializer_list init, + const Allocator& allocator = Allocator()) + : allocator_(allocator) { + if (init.size() > capacity()) { + grow(init.size()); + } + assert(capacity() >= init.size()); // sanity check + std::uninitialized_move(init.begin(), init.end(), begin_); + end_ = begin_ + init.size(); + } + + template >> + SmallVector(Iter begin, Iter end, const Allocator& allocator = Allocator()) + : allocator_(allocator) { + size_t size = std::distance(begin, end); + if (size > capacity()) { + grow(size); + } + assert(capacity() >= size); // sanity check + std::uninitialized_copy(begin, end, begin_); + end_ = begin_ + size; + } + + SmallVector(const SmallVector& other) : allocator_(other.allocator_) { + *this = other; + } + SmallVector(const SmallVector& other, const Allocator& allocator) + : allocator_(allocator) { + *this = other; + } + SmallVector(SmallVector&& other) : allocator_(std::move(other.allocator_)) { + *this = std::move(other); + } + SmallVector(SmallVector&& other, const Allocator& allocator) + : allocator_(allocator) { + *this = std::move(other); + } + + ~SmallVector() { + free_storage(); + } + + SmallVector& operator=(const SmallVector& other) { + if (this == &other) { + return *this; + } + size_t other_size = other.size(); + if (capacity() < other_size) { + // Create 
large-enough heap-allocated storage. + free_storage(); + begin_ = allocator_.allocate(other_size); + end_of_storage_ = begin_ + other_size; + std::uninitialized_copy(other.begin_, other.end_, begin_); + } else if constexpr (kHasTrivialElement) { + std::copy(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_copy = + std::min(static_cast(other_size), end_ - begin_); + std::copy(other.begin_, other.begin_ + to_copy, begin_); + if (other.begin_ + to_copy < other.end_) { + std::uninitialized_copy( + other.begin_ + to_copy, other.end_, begin_ + to_copy); + } else { + std::destroy_n(begin_ + to_copy, size() - to_copy); + } + } + end_ = begin_ + other_size; + return *this; + } + + SmallVector& operator=(SmallVector&& other) { + if (this == &other) { + return *this; + } + if (other.is_big()) { + free_storage(); + begin_ = other.begin_; + end_ = other.end_; + end_of_storage_ = other.end_of_storage_; + } else { + assert(capacity() >= other.size()); // sanity check + size_t other_size = other.size(); + if constexpr (kHasTrivialElement) { + std::move(other.begin_, other.end_, begin_); + } else { + ptrdiff_t to_move = + std::min(static_cast(other_size), end_ - begin_); + std::move(other.begin_, other.begin_ + to_move, begin_); + if (other.begin_ + to_move < other.end_) { + std::uninitialized_move( + other.begin_ + to_move, other.end_, begin_ + to_move); + } else { + std::destroy_n(begin_ + to_move, size() - to_move); + } + } + end_ = begin_ + other_size; + } + other.reset_to_inline_storage(); + return *this; + } + + bool operator==(const SmallVector& other) const { + if (size() != other.size()) { + return false; + } + return std::equal(begin_, end_, other.begin_); + } + + bool operator!=(const SmallVector& other) const { + return !(*this == other); + } + + T* data() { + return begin_; + } + const T* data() const { + return begin_; + } + + iterator begin() { + return begin_; + } + const_iterator begin() const { + return begin_; + } + + iterator end() { + return end_; + 
} + const_iterator end() const { + return end_; + } + + const_iterator cbegin() const { + return begin_; + } + + const_iterator cend() const { + return end_; + } + + auto rbegin() { + return std::make_reverse_iterator(end_); + } + auto rbegin() const { + return std::make_reverse_iterator(end_); + } + + auto rend() { + return std::make_reverse_iterator(begin_); + } + auto rend() const { + return std::make_reverse_iterator(begin_); + } + + size_t size() const { + return end_ - begin_; + } + bool empty() const { + return end_ == begin_; + } + size_t capacity() const { + return end_of_storage_ - begin_; + } + + T& front() { + assert(size() != 0); + return begin_[0]; + } + const T& front() const { + assert(size() != 0); + return begin_[0]; + } + + T& back() { + assert(size() != 0); + return end_[-1]; + } + const T& back() const { + assert(size() != 0); + return end_[-1]; + } + + T& at(size_t index) { + if (index >= size()) { + throw std::out_of_range("SmallVector out of range."); + } + return begin_[index]; + } + const T& at(size_t index) const { + return const_cast(this)->at(index); + } + + T& operator[](size_t index) { + assert(size() > index); + return begin_[index]; + } + const T& operator[](size_t index) const { + return const_cast(this)->operator[](index); + } + + template + void emplace_back(Args&&... 
args) { + if (MLX_UNLIKELY(end_ == end_of_storage_)) { + grow(); + } + void* storage = end_; + end_ += 1; + new (storage) T(std::forward(args)...); + } + + void push_back(T x) { + emplace_back(std::move(x)); + } + + void pop_back(size_t count = 1) { + assert(size() >= count); + end_ -= count; + std::destroy_n(end_, count); + } + + iterator insert(iterator pos, T value) { + return insert(pos, static_cast(1), std::move(value)); + } + + iterator insert(iterator pos, size_t count, T value) { + assert(pos <= end_); + size_t offset = pos - begin_; + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + if constexpr (kHasTrivialElement) { + std::fill_n(pos, count, value); + } else { + std::fill_n(pos + 1, count - 1, value); + *pos = std::move(value); + } + return pos; + } + + template >> + iterator insert(iterator pos, Iter begin, Iter end) { + if constexpr (std::is_same_v, iterator>) { + // The implementation can not take overlapping range. 
+ assert(!(begin >= pos && begin < pos + std::distance(begin, end))); + assert(!(end > pos && end <= pos + std::distance(begin, end))); + } + + assert(pos <= end_); + size_t offset = pos - begin_; + size_t count = std::distance(begin, end); + size_t old_size = size(); + resize(old_size + count); + pos = begin_ + offset; + iterator old_end = begin_ + old_size; + assert(old_end <= end_); + std::move_backward(pos, old_end, end_); + std::copy(begin, end, pos); + return pos; + } + + iterator insert(iterator pos, std::initializer_list values) { + return insert(pos, values.begin(), values.end()); + } + + iterator erase(iterator erase_start, iterator erase_end) { + assert(erase_start >= begin_); + assert(erase_start <= erase_end); + assert(erase_end <= end_); + iterator new_end = std::move(erase_end, end_, erase_start); + std::destroy_n(new_end, std::distance(new_end, end_)); + end_ = new_end; + return erase_start; + } + + iterator erase(iterator pos) { + return erase(pos, pos + 1); + } + + void resize(size_t new_size) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if constexpr (!kHasTrivialElement) { + if (new_end > end_) { + std::uninitialized_default_construct(end_, new_end); + } else { + std::destroy_n(new_end, end_ - new_end); + } + } + end_ = new_end; + } + + void resize(size_t new_size, const T& initial_value) { + if (new_size > capacity()) { + grow(new_size); + } + T* new_end = begin_ + new_size; + if (new_end > end_) { + std::uninitialized_fill(end_, new_end, initial_value); + } else { + std::destroy_n(new_end, end_ - new_end); + } + end_ = new_end; + } + + void reserve(size_t new_capacity) { + if (new_capacity > capacity()) { + grow(new_capacity); + } + } + + // Clear without reverting back to inline storage. + void clear() { + std::destroy_n(begin_, end_ - begin_); + end_ = begin_; + } + + private: + // Grows the backing store by a factor of two, and at least to {min_capacity}. 
+ // TODO: Move to private after removing external code using this method. + MLX_NOINLINE void grow(size_t min_capacity = 0) { + size_t new_capacity = std::max(min_capacity, 2 * capacity()); + // Round up to power of 2. + new_capacity--; + new_capacity |= new_capacity >> 1; + new_capacity |= new_capacity >> 2; + new_capacity |= new_capacity >> 4; + new_capacity |= new_capacity >> 8; + new_capacity |= new_capacity >> 16; + if constexpr (sizeof(size_t) == sizeof(uint64_t)) { + new_capacity |= new_capacity >> 32; + } + new_capacity++; + + T* new_storage = allocator_.allocate(new_capacity); + if (new_storage == nullptr) { + throw std::bad_alloc(); + } + + size_t in_use = end_ - begin_; + std::uninitialized_move(begin_, end_, new_storage); + free_storage(); + begin_ = new_storage; + end_ = new_storage + in_use; + end_of_storage_ = new_storage + new_capacity; + } + + MLX_NOINLINE void free_storage() { + std::destroy_n(begin_, end_ - begin_); + if (is_big()) { + allocator_.deallocate(begin_, end_of_storage_ - begin_); + } + } + + // Clear and go back to inline storage. Dynamic storage is *not* freed. For + // internal use only. + void reset_to_inline_storage() { + if constexpr (!kHasTrivialElement) { + if (!is_big()) + std::destroy_n(begin_, end_ - begin_); + } + begin_ = inline_storage_begin(); + end_ = begin_; + end_of_storage_ = begin_ + kSize; + } + + bool is_big() const { + return begin_ != inline_storage_begin(); + } + + T* inline_storage_begin() { + return reinterpret_cast(inline_storage_); + } + const T* inline_storage_begin() const { + return reinterpret_cast(inline_storage_); + } + + Allocator allocator_; + + // Invariants: + // 1. The elements in the range between `begin_` (included) and `end_` (not + // included) will be initialized at all times. + // 2. All other elements outside the range, both in the inline storage and in + // the dynamic storage (if it exists), will be uninitialized at all times. 
+ + T* begin_ = inline_storage_begin(); + T* end_ = begin_; + T* end_of_storage_ = begin_ + kSize; + + alignas(T) char inline_storage_[sizeof(T) * kSize]; + + static constexpr bool kHasTrivialElement = + std::is_trivially_copyable::value && + std::is_trivially_destructible::value; +}; + +template +struct is_vector : std::false_type {}; + +template +struct is_vector> : std::true_type {}; + +template +struct is_vector> : std::true_type {}; + +template +inline constexpr bool is_vector_v = is_vector::value; + +#undef MLX_HAS_BUILTIN +#undef MLX_HAS_ATTRIBUTE +#undef MLX_LIKELY +#undef MLX_UNLIKELY +#undef MLX_NOINLINE + +} // namespace mlx::core diff --git a/dist/include/mlx/stream.h b/dist/include/mlx/stream.h new file mode 100644 index 0000000..3ced403 --- /dev/null +++ b/dist/include/mlx/stream.h @@ -0,0 +1,41 @@ +// Copyright © 2023 Apple Inc. + +#pragma once + +#include "mlx/device.h" + +namespace mlx::core { + +struct Stream { + int index; + Device device; + explicit Stream(int index, Device device) : index(index), device(device) {} +}; + +/** Get the default stream for the given device. */ +Stream default_stream(Device d); + +/** Make the stream the default for its device. */ +void set_default_stream(Stream s); + +/** Make a new stream on the given device. */ +Stream new_stream(Device d); + +/** Get the stream with the given index. */ +Stream get_stream(int index); + +inline bool operator==(const Stream& lhs, const Stream& rhs) { + return lhs.index == rhs.index; +} + +inline bool operator!=(const Stream& lhs, const Stream& rhs) { + return !(lhs == rhs); +} + +/* Synchronize with the default stream. */ +void synchronize(); + +/* Synchronize with the provided stream. 
*/ +void synchronize(Stream); + +} // namespace mlx::core diff --git a/dist/include/mlx/threadpool.h b/dist/include/mlx/threadpool.h new file mode 100644 index 0000000..b0e56d0 --- /dev/null +++ b/dist/include/mlx/threadpool.h @@ -0,0 +1,133 @@ +// This code was modified from https://github.com/progschj/ThreadPool +// The original License is copied below: +// +// Copyright (c) 2012 Jakob Progsch, Václav Zeman +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. +// +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// +// 3. This notice may not be removed or altered from any source +// distribution. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { + public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + -> std::future>; + void resize(size_t); + ~ThreadPool(); + + private: + void stop_and_wait(); + void start_threads(size_t); + + std::vector workers; + std::queue> tasks; + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +inline ThreadPool::ThreadPool(size_t threads) : stop(false) { + start_threads(threads); +} + +template +auto ThreadPool::enqueue(F&& f, Args&&... 
args) + -> std::future> { + using return_type = typename std::invoke_result_t; + + auto task = std::make_shared>( + std::bind(std::forward(f), std::forward(args)...)); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + if (stop) { + throw std::runtime_error( + "[ThreadPool::enqueue] Not allowed on stopped ThreadPool"); + } + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +inline void ThreadPool::resize(size_t threads) { + if (workers.size() == threads) { + return; + } + + if (workers.size() > threads) { + stop_and_wait(); + } + start_threads(threads - workers.size()); +} + +inline ThreadPool::~ThreadPool() { + stop_and_wait(); +} + +inline void ThreadPool::stop_and_wait() { + // Stop the current threads and wait until they finish + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) { + worker.join(); + } + + // Reset the member variables so that the threadpool is reusable + stop = false; + workers.clear(); +} + +inline void ThreadPool::start_threads(size_t threads) { + for (size_t i = 0; i < threads; ++i) { + workers.emplace_back([this] { + for (;;) { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait( + lock, [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + }); + } +} diff --git a/dist/include/mlx/transforms.h b/dist/include/mlx/transforms.h new file mode 100644 index 0000000..4afb21e --- /dev/null +++ b/dist/include/mlx/transforms.h @@ -0,0 +1,229 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +#include + +#include "mlx/array.h" + +namespace mlx::core { + +void async_eval(std::vector outputs); + +template > +void async_eval(Arrays&&... 
outputs) { + async_eval(std::vector{std::forward(outputs)...}); +} + +void eval(std::vector outputs); + +template > +void eval(Arrays&&... outputs) { + eval(std::vector{std::forward(outputs)...}); +} + +/** + * Computes the output and vector-Jacobian product (VJP) of a function. + * + * Computes the vector-Jacobian product of the vector of cotangents with the + * Jacobian of the function evaluated at the primals. Returns a pair of + * vectors of output arrays and VJP arrays. + **/ +std::pair, std::vector> vjp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& cotangents); + +/** + * Computes the output and vector-Jacobian product (VJP) of a unary function. + */ +std::pair vjp( + const std::function& fun, + const array& primal, + const array& cotangent); + +/** + * Computes the output and Jacobian-vector product (JVP) of a function. + * + * Computes the Jacobian-vector product of the Jacobian of the function + * evaluated at the primals with the vector of tangents. Returns a pair of + * vectors of output arrays and JVP arrays. + **/ +std::pair, std::vector> jvp( + const std::function(const std::vector&)>& fun, + const std::vector& primals, + const std::vector& tangents); + +/** + * Computes the output and Jacobian-vector product (JVP) of a unary function. + */ +std::pair jvp( + const std::function& fun, + const array& primal, + const array& tangent); + +// Return type of general value_and_grad: a function which takes an input +// vector of arrays and returns a pair of vectors of arrays one for the +// values and one for the gradients wrt the first value. +using ValueAndGradFn = + std::function, std::vector>( + const std::vector&)>; +using SimpleValueAndGradFn = std::function>( + const std::vector&)>; + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a vector of input arrays. 
+ **/ +ValueAndGradFn value_and_grad( + const std::function(const std::vector&)>& fun, + const std::vector& argnums); + +/** + * Returns a function which computes the value and gradient of the input + * function with respect to a single input array. + **/ +ValueAndGradFn inline value_and_grad( + const std::function(const std::vector&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the value and gradient of the unary + * input function. + **/ +std::function(const array&)> inline value_and_grad( + const std::function& fun) { + return [fun](auto inputs) { return vjp(fun, inputs, array(1.0f)); }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + const std::vector& argnums) { + return [fun, argnums](auto inputs) { + auto result = value_and_grad( + [fun](auto inputs) { return std::vector{fun(inputs)}; }, + argnums)(inputs); + + return std::make_pair(result.first[0], result.second); + }; +} + +SimpleValueAndGradFn inline value_and_grad( + const std::function&)>& fun, + int argnum = 0) { + return value_and_grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a vector of input arrays. + * + * The function being differentiated takes a vector of arrays and returns an + * array. The vector of `argnums` specifies which the arguments to compute + * the gradient with respect to. At least one argument must be specified. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + const std::vector& argnums) { + auto fn = value_and_grad(fun, argnums); + return [fn](const std::vector& inputs) { return fn(inputs).second; }; +} + +/** + * Returns a function which computes the gradient of the input function with + * respect to a single input array. + * + * The function being differentiated takes a vector of arrays and returns an + * array. 
The optional `argnum` index specifies which the argument to compute + * the gradient with respect to and defaults to 0. + **/ +std::function(const std::vector&)> inline grad( + const std::function&)>& fun, + int argnum = 0) { + return grad(fun, std::vector{argnum}); +} + +/** + * Returns a function which computes the gradient of the unary input function. + **/ +std::function inline grad( + const std::function& fun) { + auto fn = value_and_grad(fun); + return [fn](const array& input) { return fn(input).second; }; +} + +/** + * Automatically vectorize a unary function over the requested axes. + */ +std::function vmap( + const std::function& fun, + int in_axis = 0, + int out_axis = 0); + +/** + * Automatically vectorize a binary function over the requested axes. + */ +std::function vmap( + const std::function& fun, + int in_axis_a = 0, + int in_axis_b = 0, + int out_axis = 0); + +/** + * Automatically vectorize a function over the requested axes. + * + * The input function to `vmap` takes as an argument a vector of arrays and + * returns a vector of arrays. Optionally specify the axes to vectorize over + * with `in_axes` and `out_axes`, otherwise a default of 0 is used. + * Returns a vectorized function with the same signature as the input + * function. + */ +std::function(const std::vector&)> vmap( + const std::function(const std::vector&)>& fun, + const std::vector& in_axes = {}, + const std::vector& out_axes = {}); + +/** + * Redefine the transformations of `fun` according to the provided functions. + * + * Namely when calling the vjp of `fun` then `fun_vjp` will be called, + * `fun_jvp` for the jvp and `fun_vmap` for vmap. + * + * If any transformation is not provided, then a default one is created by + * calling `vjp`, `jvp` and `vmap` on the function directly. 
+ */ +std::function(const std::vector&)> custom_function( + std::function(const std::vector&)> fun, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_vjp = std::nullopt, + std::optional( + const std::vector&, + const std::vector&, + const std::vector&)>> fun_jvp = std::nullopt, + std::optional, std::vector>( + const std::vector&, + const std::vector&)>> fun_vmap = std::nullopt); + +/** + * Return a function that behaves exactly like `fun` but if the vjp of the + * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)` . + */ +std::function(const std::vector&)> custom_vjp( + std::function(const std::vector&)> fun, + std::function( + const std::vector&, + const std::vector&, + const std::vector&)> fun_vjp); + +/** + * Checkpoint the gradient of a function. Namely, discard all intermediate + * state and recalculate it when we need to compute the gradient. + */ +std::function(const std::vector&)> checkpoint( + std::function(const std::vector&)> fun); + +} // namespace mlx::core diff --git a/dist/include/mlx/transforms_impl.h b/dist/include/mlx/transforms_impl.h new file mode 100644 index 0000000..46851fa --- /dev/null +++ b/dist/include/mlx/transforms_impl.h @@ -0,0 +1,86 @@ +// Copyright © 2023-2024 Apple Inc. + +#pragma once + +namespace mlx::core::detail { + +std::pair, std::vector> vmap_trace( + const std::function(const std::vector&)>& fun, + const std::vector& inputs, + const std::vector& in_axes); + +std::vector vmap_replace( + const std::vector& inputs, + const std::vector& s_inputs, + const std::vector& s_outputs, + const std::vector& in_axes, + const std::vector& out_axes); + +// Create an InTracing object during tracing operations to signify to the rest +// of the codebase that we are during tracing so evals should not throw away +// the graph. 
+struct InTracing { + explicit InTracing(bool dynamic = false, bool grad = false) { + grad_counter += grad; + trace_stack().push_back({dynamic, grad}); + } + ~InTracing() { + grad_counter -= trace_stack().back().second; + trace_stack().pop_back(); + } + + static bool in_tracing() { + return !trace_stack().empty(); + } + static bool in_dynamic_tracing() { + // compile is always and only the outer-most transform + return in_tracing() && trace_stack().front().first; + } + + static bool in_grad_tracing() { + return grad_counter > 0; + } + + private: + static int grad_counter; + static std::vector>& trace_stack(); +}; + +struct RetainGraph { + RetainGraph() { + tracing_counter++; + } + ~RetainGraph() { + tracing_counter--; + } + + static bool retain_graph() { + return tracing_counter > 0; + } + + private: + static int tracing_counter; +}; + +/** Return true if we are currently performing a function transformation in + * order to keep the graph when evaluating tracer arrays. */ +inline bool in_tracing() { + return detail::InTracing::in_tracing(); +} + +/** Return true if we are in a dynamic (shapeless) trace used for compiling or + * exporting graphs with dynamic shapes. */ +inline bool in_dynamic_tracing() { + return detail::InTracing::in_dynamic_tracing(); +} + +/** Return true if we are in a gradient trace (vjp, jvp, etc). */ +inline bool in_grad_tracing() { + return detail::InTracing::in_grad_tracing(); +} + +inline bool retain_graph() { + return detail::RetainGraph::retain_graph(); +} + +} // namespace mlx::core::detail diff --git a/dist/include/mlx/types/bf16.h b/dist/include/mlx/types/bf16.h new file mode 100644 index 0000000..5951941 --- /dev/null +++ b/dist/include/mlx/types/bf16.h @@ -0,0 +1,187 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include +#include +#include +#include + +#define __MLX_BFLOAT_NAN__ 0x7FC0 + +namespace mlx::core { + +namespace { +union float_bits_bf16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_BFloat16 { + uint16_t bits_; + + // Default constructor + _MLX_BFloat16() = default; + + // Default copy constructor + _MLX_BFloat16(_MLX_BFloat16 const&) = default; + + // Appease std::vector for being special + _MLX_BFloat16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_BFloat16& operator=(const float& x) { + return (*this = _MLX_BFloat16(x)); + } + + // From float32 + _MLX_BFloat16(const float& x) { + if (std::isnan(x)) { + bits_ = __MLX_BFLOAT_NAN__; + } else { + // Union + float_bits_bf16 in; + + // Take bits + in.f = x; + + // Round to nearest even + in.u += (in.u >> 16 & 1) + uint32_t(0x7FFF); + + // Take upper 16 bits + bits_ = in.u >> 16; + } + } + + // To float32 + operator float() const { + // Union + float_bits_bf16 out; + + // Upper 16 bits are the data and lower 16 bits are 0s + out.u = ((uint32_t)bits_) << 16; + + return out.f; + } +}; + +#define bfloat_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define bfloat_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_BFloat16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_BFloat16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define bfloat_binop(_op_, _operator_) \ + bfloat_binop_base( \ + _op_, _operator_, _MLX_BFloat16, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(_op_, _operator_, float, float, float); \ + bfloat_binop_helper(_op_, _operator_, double, double, double); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, bool, float); \ + 
bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint32_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, int64_t, float); \ + bfloat_binop_helper(_op_, _operator_, _MLX_BFloat16, uint64_t, float); + +bfloat_binop(+, operator+); +bfloat_binop(-, operator-); +bfloat_binop(*, operator*); +bfloat_binop(/, operator/); + +#undef bfloat_binop + +// Comparison ops +#define bfloat_compop(__op__, __operator__) \ + bfloat_binop_base( \ + __op__, __operator__, bool, _MLX_BFloat16, _MLX_BFloat16, float); \ + bfloat_binop_helper(__op__, __operator__, bool, float, float); \ + bfloat_binop_helper(__op__, __operator__, bool, double, double); \ + bfloat_binop_helper(__op__, __operator__, bool, int32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, int64_t, float); \ + bfloat_binop_helper(__op__, __operator__, bool, uint64_t, float); + +bfloat_compop(>, operator>); +bfloat_compop(<, operator<); +bfloat_compop(>=, operator>=); +bfloat_compop(<=, operator<=); +bfloat_compop(==, operator==); +bfloat_compop(!=, operator!=); + +#undef bfloat_compop + +// Negative +inline _MLX_BFloat16 operator-(_MLX_BFloat16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define bfloat_inplace_op(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_BFloat16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_op(+, operator+=); +bfloat_inplace_op(-, operator-=); +bfloat_inplace_op(*, operator*=); +bfloat_inplace_op(/, operator/=); + +#undef bfloat_inplace_op + +// Bitwise ops + +#define bfloat_bitop(__op__, __operator__) \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ 
rhs.bits_; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(_MLX_BFloat16 lhs, uint16_t rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_BFloat16 __operator__(uint16_t lhs, _MLX_BFloat16 rhs) { \ + _MLX_BFloat16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +bfloat_bitop(|, operator|); +bfloat_bitop(&, operator&); +bfloat_bitop(^, operator^); + +#undef bfloat_bitop + +#define bfloat_inplace_bitop(__op__, __operator__) \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, _MLX_BFloat16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_BFloat16& __operator__(_MLX_BFloat16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +bfloat_inplace_bitop(|, operator|=); +bfloat_inplace_bitop(&, operator&=); +bfloat_inplace_bitop(^, operator^=); + +#undef bfloat_inplace_bitop + +} // namespace mlx::core diff --git a/dist/include/mlx/types/complex.h b/dist/include/mlx/types/complex.h new file mode 100644 index 0000000..51101cc --- /dev/null +++ b/dist/include/mlx/types/complex.h @@ -0,0 +1,113 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +struct complex64_t; +struct complex128_t; + +template +inline constexpr bool can_convert_to_complex128 = + !std::is_same_v && std::is_convertible_v; + +struct complex128_t : public std::complex { + complex128_t() : std::complex() {}; + complex128_t(double v, double u) : std::complex(v, u) {}; + complex128_t(std::complex v) : std::complex(v) {}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex128_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +template +inline constexpr bool can_convert_to_complex64 = + !std::is_same_v && std::is_convertible_v; + +struct complex64_t : public std::complex { + complex64_t() : std::complex() {}; + complex64_t(float v, float u) : std::complex(v, u) {}; + complex64_t(std::complex v) : std::complex(v) {}; + + template < + typename T, + typename = typename std::enable_if>::type> + complex64_t(T x) : std::complex(x){}; + + operator float() const { + return real(); + }; +}; + +inline bool operator>=(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || + (a.real() == b.real() && a.imag() >= b.imag()); +} + +inline bool operator>(const complex64_t& a, const complex64_t& b) { + return (a.real() > b.real()) || (a.real() == b.real() && a.imag() > b.imag()); +} + +inline complex64_t operator%(complex64_t a, complex64_t b) { + auto real = a.real() - (b.real() * static_cast(a.real() / b.real())); + auto imag = a.imag() - (b.imag() * static_cast(a.imag() / b.imag())); + if (real != 0 && ((real < 0) != (b.real() < 0))) + real += b.real(); + if (imag != 0 && ((imag < 0) != (b.imag() < 0))) + imag += b.imag(); + return {real, imag}; +} + +inline bool operator<=(const complex64_t& a, const complex64_t& b) { + return operator>=(b, a); +} + +inline bool operator<(const complex64_t& a, const complex64_t& b) { + return operator>(b, a); +} + +inline complex64_t 
operator-(const complex64_t& v) { + return -static_cast>(v); +} + +// clang-format off +#define complex_binop_helper(_op_, _operator_, itype) \ + inline complex64_t _operator_(itype x, const complex64_t& y) { \ + return static_cast(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, itype y) { \ + return x _op_ static_cast(y); \ + } + +#define complex_binop(_op_, _operator_) \ + inline complex64_t _operator_(const std::complex& x, const complex64_t& y) { \ + return x _op_ static_cast>(y); \ + } \ + inline complex64_t _operator_(const complex64_t& x, const std::complex& y) { \ + return static_cast>(x) _op_ y; \ + } \ + inline complex64_t _operator_(const complex64_t& x, const complex64_t& y) { \ + return static_cast>(x) \ + _op_ static_cast>(y); \ + } \ + complex_binop_helper(_op_, _operator_, bool) \ + complex_binop_helper(_op_, _operator_, uint32_t) \ + complex_binop_helper(_op_, _operator_, uint64_t) \ + complex_binop_helper(_op_, _operator_, int32_t) \ + complex_binop_helper(_op_, _operator_, int64_t) \ + complex_binop_helper(_op_, _operator_, float16_t) \ + complex_binop_helper(_op_, _operator_, bfloat16_t) \ + complex_binop_helper(_op_, _operator_, float) +// clang-format on + +complex_binop(+, operator+) + +} // namespace mlx::core diff --git a/dist/include/mlx/types/fp16.h b/dist/include/mlx/types/fp16.h new file mode 100644 index 0000000..c174afe --- /dev/null +++ b/dist/include/mlx/types/fp16.h @@ -0,0 +1,234 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#include +#include +#include +#include + +#define __MLX_HALF_NAN__ 0x7D00 + +namespace mlx::core { + +namespace { +union float_bits_fp16 { + float f; + uint32_t u; +}; +} // namespace + +struct _MLX_Float16 { + uint16_t bits_; + + // Default constructor + _MLX_Float16() = default; + + // Default copy constructor + _MLX_Float16(_MLX_Float16 const&) = default; + + // Appease std::vector for being special + _MLX_Float16& operator=(std::vector::reference x) { + bits_ = x; + return *this; + } + + _MLX_Float16& operator=(const float& x) { + return (*this = _MLX_Float16(x)); + } + + // From float32 + _MLX_Float16(const float& x) : bits_(0) { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 in; + + // Take fp32 bits + in.f = x; + + // Find and take sign bit + uint32_t x_sign_32 = in.u & uint32_t(0x80000000); + uint16_t x_sign_16 = (x_sign_32 >> 16); + + if (std::isnan(x)) { + bits_ = x_sign_16 | uint16_t(__MLX_HALF_NAN__); + } else { + // Union + float_bits_fp16 inf_scale, zero_scale, magic_bits; + + // Find exponent bits and take the max supported by half + uint32_t x_expo_32 = in.u & uint32_t(0x7f800000); + uint32_t max_expo_32 = uint32_t(0x38800000); + x_expo_32 = x_expo_32 < max_expo_32 ? 
max_expo_32 : x_expo_32; + x_expo_32 += uint32_t(15) << 23; + + // Handle scaling to inf as needed + inf_scale.u = uint32_t(0x77800000); + zero_scale.u = uint32_t(0x08800000); + + // Combine with magic and let addition do rounding + magic_bits.u = x_expo_32; + magic_bits.f += (std::abs(x) * inf_scale.f) * zero_scale.f; + + // Take the lower 5 bits of the exponent + uint32_t x_expo_16 = ((magic_bits.u >> 13) & uint32_t(0x7c00)); + + // Collect the lower 12 bits which have the mantissa + uint32_t x_mant_16 = magic_bits.u & uint32_t(0x0fff); + + // Combine sign, exp and mantissa + bits_ = (x_sign_16 | uint16_t(x_expo_16 + x_mant_16)); + } + } + + // To float32 + operator float() const { + // Conversion following + // https://github.com/Maratyszcza/FP16/blob/master/include/fp16/fp16.h + + // Union + float_bits_fp16 out; + + uint32_t x_sign_32 = (bits_ << 16) & uint32_t(0x80000000); + uint32_t base = (bits_ << 16); + uint32_t two_base = base + base; + + uint32_t denorm_max = 1u << 27; + if (two_base < denorm_max) { + out.u = uint32_t(126) << 23; // magic mask + out.u |= (two_base >> 17); // Bits from fp16 + out.f -= 0.5f; // magic bias + } else { + out.u = uint32_t(0xE0) << 23; // exponent offset + out.u += (two_base >> 4); // Bits from fp16 + float out_unscaled = out.f; // Store value + out.u = uint32_t(0x7800000); // exponent scale + out.f *= out_unscaled; + } + + // Add sign + out.u |= x_sign_32; + + return out.f; + } +}; + +#define half_binop_base(__op__, __operator__, otype, atype, btype, ctype) \ + inline otype __operator__(atype lhs, btype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +#define half_binop_helper(__op__, __operator__, otype, itype, ctype) \ + inline otype __operator__(_MLX_Float16 lhs, itype rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline otype __operator__(itype lhs, _MLX_Float16 rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +// Operators +#define half_binop(__op__, 
__operator__) \ + half_binop_base( \ + __op__, __operator__, _MLX_Float16, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, float, float, float); \ + half_binop_helper(__op__, __operator__, double, double, double); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, bool, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint32_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, int64_t, float); \ + half_binop_helper(__op__, __operator__, _MLX_Float16, uint64_t, float); + +half_binop(+, operator+); +half_binop(-, operator-); +half_binop(*, operator*); +half_binop(/, operator/); + +#undef half_binop + +// Comparison ops +#define half_compop(__op__, __operator__) \ + half_binop_base( \ + __op__, __operator__, bool, _MLX_Float16, _MLX_Float16, float); \ + half_binop_helper(__op__, __operator__, bool, float, float); \ + half_binop_helper(__op__, __operator__, bool, double, double); \ + half_binop_helper(__op__, __operator__, bool, int32_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint32_t, float); \ + half_binop_helper(__op__, __operator__, bool, int64_t, float); \ + half_binop_helper(__op__, __operator__, bool, uint64_t, float); + +half_compop(>, operator>); +half_compop(<, operator<); +half_compop(>=, operator>=); +half_compop(<=, operator<=); +half_compop(==, operator==); +half_compop(!=, operator!=); + +#undef half_compop + +// Negative +inline _MLX_Float16 operator-(_MLX_Float16 lhs) { + return -static_cast(lhs); +} + +// Inplace ops +#define half_inplace_op(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, const float& rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } \ + inline float& __operator__(float& lhs, _MLX_Float16 rhs) { \ + lhs = lhs __op__ rhs; \ + return lhs; \ + } + +half_inplace_op(+, operator+=); +half_inplace_op(-, operator-=); +half_inplace_op(*, 
operator*=); +half_inplace_op(/, operator/=); + +#undef half_inplace_op + +// Bitwise ops + +#define half_bitop(__op__, __operator__) \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(_MLX_Float16 lhs, uint16_t rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs.bits_ __op__ rhs; \ + return out; \ + } \ + inline _MLX_Float16 __operator__(uint16_t lhs, _MLX_Float16 rhs) { \ + _MLX_Float16 out; \ + out.bits_ = lhs __op__ rhs.bits_; \ + return out; \ + } + +half_bitop(|, operator|); +half_bitop(&, operator&); +half_bitop(^, operator^); + +#undef half_bitop + +#define half_inplace_bitop(__op__, __operator__) \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, _MLX_Float16 rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs.bits_; \ + return lhs; \ + } \ + inline _MLX_Float16& __operator__(_MLX_Float16& lhs, uint16_t rhs) { \ + lhs.bits_ = lhs.bits_ __op__ rhs; \ + return lhs; \ + } + +half_inplace_bitop(|, operator|=); +half_inplace_bitop(&, operator&=); +half_inplace_bitop(^, operator^=); + +#undef half_inplace_bitop + +} // namespace mlx::core diff --git a/dist/include/mlx/types/half_types.h b/dist/include/mlx/types/half_types.h new file mode 100644 index 0000000..d9d6b9b --- /dev/null +++ b/dist/include/mlx/types/half_types.h @@ -0,0 +1,58 @@ +// Copyright © 2023 Apple Inc. 
+ +#pragma once + +#ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#include +namespace mlx::core { +using ::float16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/fp16.h" +namespace mlx::core { +typedef struct _MLX_Float16 float16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC + +#ifdef __ARM_FEATURE_BF16 + +#include +namespace mlx::core { +using ::bfloat16_t; +} // namespace mlx::core + +#else + +#define ADD_HALF_BINOPS +#include "mlx/types/bf16.h" +namespace mlx::core { +typedef struct _MLX_BFloat16 bfloat16_t; +} // namespace mlx::core + +#endif // __ARM_FEATURE_BF16 + +#ifdef ADD_HALF_BINOPS +namespace mlx::core { + +// clang-format off +#define fp16_bf16_binop_helper(__op__, __operator__) \ + inline float __operator__(float16_t lhs, bfloat16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } \ + inline float __operator__(bfloat16_t lhs, float16_t rhs) { \ + return static_cast(lhs) __op__ static_cast(rhs); \ + } + +fp16_bf16_binop_helper(+, operator+) +fp16_bf16_binop_helper(-, operator-) +fp16_bf16_binop_helper(*, operator*) +fp16_bf16_binop_helper(/, operator/) +// clang-format on + +} // namespace mlx::core +#endif diff --git a/dist/include/mlx/types/limits.h b/dist/include/mlx/types/limits.h new file mode 100644 index 0000000..5f2b1e9 --- /dev/null +++ b/dist/include/mlx/types/limits.h @@ -0,0 +1,70 @@ +// Copyright © 2024 Apple Inc. 
+#pragma once + +#include +#include "mlx/types/half_types.h" + +namespace mlx::core { + +template +struct numeric_limits; + +template <> +struct numeric_limits : public std::numeric_limits {}; + +template <> +struct numeric_limits : public std::numeric_limits {}; + +template <> +struct numeric_limits { + private: + union half_or_bits { + uint16_t bits; + float16_t value; + }; + constexpr static float16_t bits_to_half(uint16_t v) { + return half_or_bits{v}.value; + } + + public: + constexpr static float16_t lowest() { + return bits_to_half(0xFBFF); + } + static constexpr float16_t max() { + return bits_to_half(0x7BFF); + } + static constexpr float16_t epsilon() { + return bits_to_half(0x1400); + } + static constexpr float16_t infinity() { + return bits_to_half(0x7C00); + } +}; + +template <> +struct numeric_limits { + private: + union bfloat_or_bits { + uint16_t bits; + bfloat16_t value; + }; + constexpr static bfloat16_t bits_to_bfloat(uint16_t v) { + return bfloat_or_bits{v}.value; + } + + public: + constexpr static bfloat16_t lowest() { + return bits_to_bfloat(0xFF7F); + } + static constexpr bfloat16_t max() { + return bits_to_bfloat(0x7F7F); + } + static constexpr bfloat16_t epsilon() { + return bits_to_bfloat(0x3C00); + } + static constexpr bfloat16_t infinity() { + return bits_to_bfloat(0x7F80); + } +}; + +} // namespace mlx::core diff --git a/dist/include/mlx/utils.h b/dist/include/mlx/utils.h new file mode 100644 index 0000000..dbf79a7 --- /dev/null +++ b/dist/include/mlx/utils.h @@ -0,0 +1,175 @@ +// Copyright © 2023-2024 Apple Inc. 
+ +#pragma once + +#include +#include + +#include "mlx/array.h" +#include "mlx/device.h" +#include "mlx/dtype.h" +#include "mlx/stream.h" + +namespace mlx::core { + +using StreamOrDevice = std::variant; +Stream to_stream(StreamOrDevice s); +Stream to_stream(StreamOrDevice s, Device default_); + +struct StreamContext { + public: + StreamContext(StreamOrDevice s) : _stream(default_stream(default_device())) { + if (std::holds_alternative(s)) { + throw std::runtime_error( + "[StreamContext] Invalid argument, please specify a stream or device."); + } + auto _s = to_stream(s); + set_default_device(_s.device); + set_default_stream(_s); + } + + ~StreamContext() { + set_default_device(_stream.device); + set_default_stream(_stream); + } + + private: + Stream _stream; +}; + +struct PrintFormatter { + inline void print(std::ostream& os, bool val); + inline void print(std::ostream& os, int16_t val); + inline void print(std::ostream& os, uint16_t val); + inline void print(std::ostream& os, int32_t val); + inline void print(std::ostream& os, uint32_t val); + inline void print(std::ostream& os, int64_t val); + inline void print(std::ostream& os, uint64_t val); + inline void print(std::ostream& os, float16_t val); + inline void print(std::ostream& os, bfloat16_t val); + inline void print(std::ostream& os, float val); + inline void print(std::ostream& os, double val); + inline void print(std::ostream& os, complex64_t val); + + bool capitalize_bool{false}; +}; + +PrintFormatter& get_global_formatter(); + +/** Print the exception and then abort. */ +void abort_with_exception(const std::exception& error); + +/** Holds information about floating-point types. */ +struct finfo { + explicit finfo(Dtype dtype); + Dtype dtype; + double min; + double max; + double eps; +}; + +/** Holds information about integral types. */ +struct iinfo { + explicit iinfo(Dtype dtype); + Dtype dtype; + int64_t min; + uint64_t max; +}; + +/** The type from promoting the arrays' types with one another. 
*/ +inline Dtype result_type(const array& a, const array& b) { + return promote_types(a.dtype(), b.dtype()); +} +inline Dtype result_type(const array& a, const array& b, const array& c) { + return promote_types(result_type(a, b), c.dtype()); +} +Dtype result_type(const std::vector& arrays); + +Shape broadcast_shapes(const Shape& s1, const Shape& s2); + +/** + * Returns the axis normalized to be in the range [0, ndim). + */ +int normalize_axis_index( + int axis, + int ndim, + const std::string& msg_prefix = ""); + +std::ostream& operator<<(std::ostream& os, const Device& d); +std::ostream& operator<<(std::ostream& os, const Stream& s); +std::ostream& operator<<(std::ostream& os, const Dtype& d); +std::ostream& operator<<(std::ostream& os, const Dtype::Kind& k); +std::ostream& operator<<(std::ostream& os, array a); +inline std::ostream& operator<<(std::ostream& os, const complex64_t& v) { + return os << v.real() << (v.imag() >= 0 ? "+" : "") << v.imag() << "j"; +} +inline std::ostream& operator<<(std::ostream& os, const float16_t& v) { + return os << static_cast(v); +} +inline std::ostream& operator<<(std::ostream& os, const bfloat16_t& v) { + return os << static_cast(v); +} + +template >> +inline std::ostream& operator<<(std::ostream& os, const Vec& v) { + os << "("; + for (auto it = v.begin(); it != v.end(); ++it) { + os << *it; + if (it != std::prev(v.end())) { + os << ","; + } + } + os << ")"; + return os; +} + +inline bool is_power_of_2(int n) { + return ((n & (n - 1)) == 0) && n != 0; +} + +inline int next_power_of_2(int n) { + if (is_power_of_2(n)) { + return n; + } + return pow(2, std::ceil(std::log2(n))); +} + +namespace env { + +int get_var(const char* name, int default_value); + +inline int bfs_max_width() { + static int bfs_max_width_ = get_var("MLX_BFS_MAX_WIDTH", 20); + return bfs_max_width_; +} + +inline int max_ops_per_buffer(int default_value) { + static int max_ops_per_buffer_ = + get_var("MLX_MAX_OPS_PER_BUFFER", default_value); + return 
max_ops_per_buffer_; +} + +inline int max_mb_per_buffer(int default_value) { + static int max_mb_per_buffer_ = + get_var("MLX_MAX_MB_PER_BUFFER", default_value); + return max_mb_per_buffer_; +} + +inline bool metal_fast_synch() { + static bool metal_fast_synch = get_var("MLX_METAL_FAST_SYNCH", 0); + return metal_fast_synch; +} + +inline bool enable_tf32() { + static bool enable_tf32_ = get_var("MLX_ENABLE_TF32", 1); + return enable_tf32_; +} + +inline int nccl_timeout(int default_value) { + static int nccl_timeout = get_var("MLX_NCCL_TIMEOUT", default_value); + return nccl_timeout; +} + +} // namespace env + +} // namespace mlx::core diff --git a/dist/include/mlx/version.h b/dist/include/mlx/version.h new file mode 100644 index 0000000..964630c --- /dev/null +++ b/dist/include/mlx/version.h @@ -0,0 +1,20 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#define MLX_VERSION_MAJOR 0 +#define MLX_VERSION_MINOR 30 +#define MLX_VERSION_PATCH 1 +#define MLX_VERSION_NUMERIC \ + (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH) + +namespace mlx::core { + +/* A string representation of the MLX version in the format + * "major.minor.patch". + * + * For dev builds, the version will include the suffix ".devYYYYMMDD+hash" + */ +const char* version(); + +} // namespace mlx::core