// Copyright © 2024 Apple Inc. #pragma once #include #include "mlx/distributed/distributed.h" #include "mlx/utils.h" namespace mlx::core::distributed { array all_sum( const array& x, std::optional group = std::nullopt, StreamOrDevice s = {}); array all_gather( const array& x, std::optional group = std::nullopt, StreamOrDevice S = {}); array send( const array& x, int dst, std::optional group = std::nullopt, StreamOrDevice s = {}); array recv( Shape shape, Dtype dtype, int src, std::optional group = std::nullopt, StreamOrDevice s = {}); array recv_like( const array& x, int src, std::optional group = std::nullopt, StreamOrDevice s = {}); array all_max( const array& x, std::optional group = std::nullopt, StreamOrDevice s = {}); array all_min( const array& x, std::optional group = std::nullopt, StreamOrDevice s = {}); array sum_scatter( const array& x, std::optional group = std::nullopt, StreamOrDevice s = {}); } // namespace mlx::core::distributed