// Copyright © 2024 Apple Inc. #include "mlx/io/load.h" #pragma once namespace mlx::core { struct FunctionTable; struct FunctionExporter { void operator()(const std::initializer_list& args) { this->operator()(Args(args)); } void operator()(const Args& args); void operator()(const Kwargs& kwargs); void operator()(const Args& args, const Kwargs& kwargs); void close(); FunctionExporter(const FunctionExporter&) = delete; FunctionExporter& operator=(const FunctionExporter&) = delete; FunctionExporter(FunctionExporter&& other) = default; private: friend FunctionExporter exporter( const std::string&, const std::function(const Args&)>&, bool shapeless); friend FunctionExporter exporter( const std::string&, const std::function(const Kwargs&)>&, bool shapeless); friend FunctionExporter exporter( const std::string&, const std::function(const Args&, const Kwargs&)>&, bool shapeless); friend FunctionExporter exporter( const ExportCallback&, const std::function(const Args&)>&, bool shapeless); friend FunctionExporter exporter( const ExportCallback&, const std::function(const Kwargs&)>&, bool shapeless); friend FunctionExporter exporter( const ExportCallback&, const std::function(const Args&, const Kwargs&)>&, bool shapeless); FunctionExporter( const std::string& file, std::function(const Args&, const Kwargs&)> fun, bool shapeless); FunctionExporter( const ExportCallback& callback, std::function(const Args&, const Kwargs&)> fun, bool shapeless); io::FileWriter os; ExportCallback callback; std::function(const Args&, const Kwargs& kwargs)> fun; void export_function(const Args& args, const Kwargs& kwargs); void export_with_callback( const std::vector& inputs, const std::vector& outputs, const std::vector& tape, const std::vector& kwarg_keys); std::unordered_map constants; int count{0}; bool closed{false}; std::shared_ptr ftable; }; struct ImportedFunction { std::vector operator()( const std::initializer_list& args) const { return this->operator()(Args(args)); } std::vector operator()(const Args& args) const; std::vector operator()(const Kwargs& kwargs) const; std::vector operator()(const Args& args, const Kwargs& kwargs) const; private: ImportedFunction(const std::string& file); friend ImportedFunction import_function(const std::string&); ImportedFunction(); std::shared_ptr ftable; }; } // namespace mlx::core