// Copyright © 2024 Apple Inc. #pragma once #include #include #include #include #include "mlx/array.h" namespace mlx::core { using Args = std::vector; using Kwargs = std::unordered_map; // Possible types for a Primitive's state using StateT = std::variant< bool, int, size_t, float, double, Dtype, Shape, Strides, std::vector, std::vector, std::vector>, std::vector>, std::optional, std::string>; using ExportCallbackInput = std::unordered_map< std::string, std::variant< std::vector>, std::vector>, std::vector>, std::vector, std::string>>; using ExportCallback = std::function; struct FunctionExporter; /** * Make an exporter to save multiple traces of a given function to * the same file. */ FunctionExporter exporter( const std::string& file, const std::function(const Args&)>& fun, bool shapeless = false); FunctionExporter exporter( const std::string& file, const std::function(const Kwargs&)>& fun, bool shapeless = false); FunctionExporter exporter( const std::string& path, const std::function(const Args&, const Kwargs&)>& fun, bool shapeless = false); /** * Export a function to a file. */ void export_function( const std::string& file, const std::function(const Args&)>& fun, const Args& args, bool shapeless = false); void export_function( const std::string& file, const std::function(const Kwargs&)>& fun, const Kwargs& kwargs, bool shapeless = false); void export_function( const std::string& file, const std::function(const Args&, const Kwargs&)>& fun, const Args& args, const Kwargs& kwargs, bool shapeless = false); struct ImportedFunction; /** * Import a function from a file. */ ImportedFunction import_function(const std::string& file); /** * Make an exporter to export multiple traces of a given function with the same * callback. */ FunctionExporter exporter( const ExportCallback& callback, const std::function(const Args&)>& fun, bool shapeless = false); FunctionExporter exporter( const ExportCallback& callback, const std::function(const Kwargs&)>& fun, bool shapeless = false); FunctionExporter exporter( const ExportCallback& callback, const std::function(const Args&, const Kwargs&)>& fun, bool shapeless = false); /** * Export a function with a callback. */ void export_function( const ExportCallback& callback, const std::function(const Args&)>& fun, const Args& args, bool shapeless = false); void export_function( const ExportCallback& callback, const std::function(const Kwargs&)>& fun, const Kwargs& kwargs, bool shapeless = false); void export_function( const ExportCallback& callback, const std::function(const Args&, const Kwargs&)>& fun, const Args& args, const Kwargs& kwargs, bool shapeless = false); } // namespace mlx::core #include "mlx/export_impl.h"