go-mlx/dist/include/metal_cpp/Metal/MTLComputePipeline.hpp
Snider 2292557fd6 chore: vendor MLX C headers for Go module consumers
dist/include/ contains the MLX and MLX-C headers needed for CGo
compilation. Without these, go-mlx cannot be used as a module
dependency (headers not found in module cache).

Libraries (dylib/metallib) are still gitignored — users build
those locally via cmake.

Co-Authored-By: Virgil <virgil@lethean.io>
2026-02-21 19:14:04 +00:00

439 lines
17 KiB
C++

//-------------------------------------------------------------------------------------------------------------------------------------------------------------
//
// Metal/MTLComputePipeline.hpp
//
// Copyright 2020-2025 Apple Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//-------------------------------------------------------------------------------------------------------------------------------------------------------------
#pragma once
#include "../Foundation/Foundation.hpp"
#include "MTLAllocation.hpp"
#include "MTLDefines.hpp"
#include "MTLHeaderBridge.hpp"
#include "MTLPipeline.hpp"
#include "MTLPrivate.hpp"
#include "MTLTypes.hpp"
namespace MTL
{
class ComputePipelineDescriptor;
class ComputePipelineReflection;
class ComputePipelineState;
class Device;
class Function;
class FunctionHandle;
class IntersectionFunctionTable;
class IntersectionFunctionTableDescriptor;
class LinkedFunctions;
class PipelineBufferDescriptorArray;
class StageInputOutputDescriptor;
class VisibleFunctionTable;
class VisibleFunctionTableDescriptor;
}
namespace MTL4
{
class BinaryFunction;
}
namespace MTL
{
class ComputePipelineReflection : public NS::Referencing<ComputePipelineReflection>
{
public:
static ComputePipelineReflection* alloc();
NS::Array* arguments() const;
NS::Array* bindings() const;
ComputePipelineReflection* init();
};
class ComputePipelineDescriptor : public NS::Copying<ComputePipelineDescriptor>
{
public:
static ComputePipelineDescriptor* alloc();
NS::Array* binaryArchives() const;
PipelineBufferDescriptorArray* buffers() const;
Function* computeFunction() const;
ComputePipelineDescriptor* init();
NS::Array* insertLibraries() const;
NS::String* label() const;
LinkedFunctions* linkedFunctions() const;
NS::UInteger maxCallStackDepth() const;
NS::UInteger maxTotalThreadsPerThreadgroup() const;
NS::Array* preloadedLibraries() const;
Size requiredThreadsPerThreadgroup() const;
void reset();
void setBinaryArchives(const NS::Array* binaryArchives);
void setComputeFunction(const MTL::Function* computeFunction);
void setInsertLibraries(const NS::Array* insertLibraries);
void setLabel(const NS::String* label);
void setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions);
void setMaxCallStackDepth(NS::UInteger maxCallStackDepth);
void setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup);
void setPreloadedLibraries(const NS::Array* preloadedLibraries);
void setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup);
void setShaderValidation(MTL::ShaderValidation shaderValidation);
void setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor);
void setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions);
void setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers);
void setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth);
ShaderValidation shaderValidation() const;
StageInputOutputDescriptor* stageInputDescriptor() const;
bool supportAddingBinaryFunctions() const;
bool supportIndirectCommandBuffers() const;
bool threadGroupSizeIsMultipleOfThreadExecutionWidth() const;
};
class ComputePipelineState : public NS::Referencing<ComputePipelineState, Allocation>
{
public:
Device* device() const;
FunctionHandle* functionHandle(const NS::String* name);
FunctionHandle* functionHandle(const MTL4::BinaryFunction* function);
FunctionHandle* functionHandle(const MTL::Function* function);
ResourceID gpuResourceID() const;
NS::UInteger imageblockMemoryLength(MTL::Size imageblockDimensions);
NS::String* label() const;
NS::UInteger maxTotalThreadsPerThreadgroup() const;
ComputePipelineState* newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error);
ComputePipelineState* newComputePipelineState(const NS::Array* functions, NS::Error** error);
IntersectionFunctionTable* newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor);
VisibleFunctionTable* newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor);
ComputePipelineReflection* reflection() const;
Size requiredThreadsPerThreadgroup() const;
ShaderValidation shaderValidation() const;
NS::UInteger staticThreadgroupMemoryLength() const;
bool supportIndirectCommandBuffers() const;
NS::UInteger threadExecutionWidth() const;
};
}
_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::alloc()
{
return NS::Object::alloc<MTL::ComputePipelineReflection>(_MTL_PRIVATE_CLS(MTLComputePipelineReflection));
}
_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::arguments() const
{
return Object::sendMessage<NS::Array*>(this, _MTL_PRIVATE_SEL(arguments));
}
_MTL_INLINE NS::Array* MTL::ComputePipelineReflection::bindings() const
{
return Object::sendMessage<NS::Array*>(this, _MTL_PRIVATE_SEL(bindings));
}
_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineReflection::init()
{
return NS::Object::init<MTL::ComputePipelineReflection>();
}
_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::alloc()
{
return NS::Object::alloc<MTL::ComputePipelineDescriptor>(_MTL_PRIVATE_CLS(MTLComputePipelineDescriptor));
}
_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::binaryArchives() const
{
return Object::sendMessage<NS::Array*>(this, _MTL_PRIVATE_SEL(binaryArchives));
}
_MTL_INLINE MTL::PipelineBufferDescriptorArray* MTL::ComputePipelineDescriptor::buffers() const
{
return Object::sendMessage<MTL::PipelineBufferDescriptorArray*>(this, _MTL_PRIVATE_SEL(buffers));
}
_MTL_INLINE MTL::Function* MTL::ComputePipelineDescriptor::computeFunction() const
{
return Object::sendMessage<MTL::Function*>(this, _MTL_PRIVATE_SEL(computeFunction));
}
_MTL_INLINE MTL::ComputePipelineDescriptor* MTL::ComputePipelineDescriptor::init()
{
return NS::Object::init<MTL::ComputePipelineDescriptor>();
}
_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::insertLibraries() const
{
return Object::sendMessage<NS::Array*>(this, _MTL_PRIVATE_SEL(insertLibraries));
}
_MTL_INLINE NS::String* MTL::ComputePipelineDescriptor::label() const
{
return Object::sendMessage<NS::String*>(this, _MTL_PRIVATE_SEL(label));
}
_MTL_INLINE MTL::LinkedFunctions* MTL::ComputePipelineDescriptor::linkedFunctions() const
{
return Object::sendMessage<MTL::LinkedFunctions*>(this, _MTL_PRIVATE_SEL(linkedFunctions));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxCallStackDepth() const
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(maxCallStackDepth));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineDescriptor::maxTotalThreadsPerThreadgroup() const
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup));
}
_MTL_INLINE NS::Array* MTL::ComputePipelineDescriptor::preloadedLibraries() const
{
return Object::sendMessage<NS::Array*>(this, _MTL_PRIVATE_SEL(preloadedLibraries));
}
_MTL_INLINE MTL::Size MTL::ComputePipelineDescriptor::requiredThreadsPerThreadgroup() const
{
return Object::sendMessage<MTL::Size>(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup));
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::reset()
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(reset));
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setBinaryArchives(const NS::Array* binaryArchives)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setBinaryArchives_), binaryArchives);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setComputeFunction(const MTL::Function* computeFunction)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setComputeFunction_), computeFunction);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setInsertLibraries(const NS::Array* insertLibraries)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setInsertLibraries_), insertLibraries);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setLabel(const NS::String* label)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setLabel_), label);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setLinkedFunctions(const MTL::LinkedFunctions* linkedFunctions)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setLinkedFunctions_), linkedFunctions);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxCallStackDepth(NS::UInteger maxCallStackDepth)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setMaxCallStackDepth_), maxCallStackDepth);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setMaxTotalThreadsPerThreadgroup(NS::UInteger maxTotalThreadsPerThreadgroup)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setMaxTotalThreadsPerThreadgroup_), maxTotalThreadsPerThreadgroup);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setPreloadedLibraries(const NS::Array* preloadedLibraries)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setPreloadedLibraries_), preloadedLibraries);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setRequiredThreadsPerThreadgroup(MTL::Size requiredThreadsPerThreadgroup)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setRequiredThreadsPerThreadgroup_), requiredThreadsPerThreadgroup);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setShaderValidation(MTL::ShaderValidation shaderValidation)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setShaderValidation_), shaderValidation);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setStageInputDescriptor(const MTL::StageInputOutputDescriptor* stageInputDescriptor)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setStageInputDescriptor_), stageInputDescriptor);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportAddingBinaryFunctions(bool supportAddingBinaryFunctions)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSupportAddingBinaryFunctions_), supportAddingBinaryFunctions);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setSupportIndirectCommandBuffers(bool supportIndirectCommandBuffers)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setSupportIndirectCommandBuffers_), supportIndirectCommandBuffers);
}
_MTL_INLINE void MTL::ComputePipelineDescriptor::setThreadGroupSizeIsMultipleOfThreadExecutionWidth(bool threadGroupSizeIsMultipleOfThreadExecutionWidth)
{
Object::sendMessage<void>(this, _MTL_PRIVATE_SEL(setThreadGroupSizeIsMultipleOfThreadExecutionWidth_), threadGroupSizeIsMultipleOfThreadExecutionWidth);
}
_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineDescriptor::shaderValidation() const
{
return Object::sendMessage<MTL::ShaderValidation>(this, _MTL_PRIVATE_SEL(shaderValidation));
}
_MTL_INLINE MTL::StageInputOutputDescriptor* MTL::ComputePipelineDescriptor::stageInputDescriptor() const
{
return Object::sendMessage<MTL::StageInputOutputDescriptor*>(this, _MTL_PRIVATE_SEL(stageInputDescriptor));
}
_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportAddingBinaryFunctions() const
{
return Object::sendMessageSafe<bool>(this, _MTL_PRIVATE_SEL(supportAddingBinaryFunctions));
}
_MTL_INLINE bool MTL::ComputePipelineDescriptor::supportIndirectCommandBuffers() const
{
return Object::sendMessageSafe<bool>(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers));
}
_MTL_INLINE bool MTL::ComputePipelineDescriptor::threadGroupSizeIsMultipleOfThreadExecutionWidth() const
{
return Object::sendMessage<bool>(this, _MTL_PRIVATE_SEL(threadGroupSizeIsMultipleOfThreadExecutionWidth));
}
_MTL_INLINE MTL::Device* MTL::ComputePipelineState::device() const
{
return Object::sendMessage<MTL::Device*>(this, _MTL_PRIVATE_SEL(device));
}
_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const NS::String* name)
{
return Object::sendMessage<MTL::FunctionHandle*>(this, _MTL_PRIVATE_SEL(functionHandleWithName_), name);
}
_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL4::BinaryFunction* function)
{
return Object::sendMessage<MTL::FunctionHandle*>(this, _MTL_PRIVATE_SEL(functionHandleWithBinaryFunction_), function);
}
_MTL_INLINE MTL::FunctionHandle* MTL::ComputePipelineState::functionHandle(const MTL::Function* function)
{
return Object::sendMessage<MTL::FunctionHandle*>(this, _MTL_PRIVATE_SEL(functionHandleWithFunction_), function);
}
_MTL_INLINE MTL::ResourceID MTL::ComputePipelineState::gpuResourceID() const
{
return Object::sendMessage<MTL::ResourceID>(this, _MTL_PRIVATE_SEL(gpuResourceID));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineState::imageblockMemoryLength(MTL::Size imageblockDimensions)
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(imageblockMemoryLengthForDimensions_), imageblockDimensions);
}
_MTL_INLINE NS::String* MTL::ComputePipelineState::label() const
{
return Object::sendMessage<NS::String*>(this, _MTL_PRIVATE_SEL(label));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineState::maxTotalThreadsPerThreadgroup() const
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(maxTotalThreadsPerThreadgroup));
}
_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineStateWithBinaryFunctions(const NS::Array* additionalBinaryFunctions, NS::Error** error)
{
return Object::sendMessage<MTL::ComputePipelineState*>(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithBinaryFunctions_error_), additionalBinaryFunctions, error);
}
_MTL_INLINE MTL::ComputePipelineState* MTL::ComputePipelineState::newComputePipelineState(const NS::Array* functions, NS::Error** error)
{
return Object::sendMessage<MTL::ComputePipelineState*>(this, _MTL_PRIVATE_SEL(newComputePipelineStateWithAdditionalBinaryFunctions_error_), functions, error);
}
_MTL_INLINE MTL::IntersectionFunctionTable* MTL::ComputePipelineState::newIntersectionFunctionTable(const MTL::IntersectionFunctionTableDescriptor* descriptor)
{
return Object::sendMessage<MTL::IntersectionFunctionTable*>(this, _MTL_PRIVATE_SEL(newIntersectionFunctionTableWithDescriptor_), descriptor);
}
_MTL_INLINE MTL::VisibleFunctionTable* MTL::ComputePipelineState::newVisibleFunctionTable(const MTL::VisibleFunctionTableDescriptor* descriptor)
{
return Object::sendMessage<MTL::VisibleFunctionTable*>(this, _MTL_PRIVATE_SEL(newVisibleFunctionTableWithDescriptor_), descriptor);
}
_MTL_INLINE MTL::ComputePipelineReflection* MTL::ComputePipelineState::reflection() const
{
return Object::sendMessage<MTL::ComputePipelineReflection*>(this, _MTL_PRIVATE_SEL(reflection));
}
_MTL_INLINE MTL::Size MTL::ComputePipelineState::requiredThreadsPerThreadgroup() const
{
return Object::sendMessage<MTL::Size>(this, _MTL_PRIVATE_SEL(requiredThreadsPerThreadgroup));
}
_MTL_INLINE MTL::ShaderValidation MTL::ComputePipelineState::shaderValidation() const
{
return Object::sendMessage<MTL::ShaderValidation>(this, _MTL_PRIVATE_SEL(shaderValidation));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineState::staticThreadgroupMemoryLength() const
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(staticThreadgroupMemoryLength));
}
_MTL_INLINE bool MTL::ComputePipelineState::supportIndirectCommandBuffers() const
{
return Object::sendMessageSafe<bool>(this, _MTL_PRIVATE_SEL(supportIndirectCommandBuffers));
}
_MTL_INLINE NS::UInteger MTL::ComputePipelineState::threadExecutionWidth() const
{
return Object::sendMessage<NS::UInteger>(this, _MTL_PRIVATE_SEL(threadExecutionWidth));
}