// Copyright © 2024 Apple Inc. #pragma once #include "mlx/distributed/distributed.h" #include "mlx/distributed/distributed_impl.h" #include "mlx/primitives.h" namespace mlx::core::distributed { class DistPrimitive : public Primitive { public: DistPrimitive(Stream stream, Group group) : Primitive(stream), group_(group) {} const Group& group() const { return group_; } private: Group group_; }; class AllReduce : public DistPrimitive { public: enum ReduceType { And, Or, Sum, Prod, Min, Max }; AllReduce(Stream stream, Group group, ReduceType reduce_type) : DistPrimitive(stream, group), reduce_type_(reduce_type) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; const char* name() const override { switch (reduce_type_) { case And: return "And AllReduce"; case Or: return "Or AllReduce"; case Sum: return "Sum AllReduce"; case Prod: return "Prod AllReduce"; case Min: return "Min AllReduce"; case Max: return "Max AllReduce"; } return ""; } private: ReduceType reduce_type_; }; class AllGather : public DistPrimitive { public: AllGather(Stream stream, Group group) : DistPrimitive(stream, group) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; std::vector jvp( const std::vector& primals, const std::vector& tangents, const std::vector& argnums) override; std::vector vjp( const std::vector& primals, const std::vector& cotangents, const std::vector& argnums, const std::vector& outputs) override; DEFINE_NAME(AllGather); }; class Send : public DistPrimitive { public: Send(Stream stream, Group group, int dst) : DistPrimitive(stream, group), dst_(dst) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; std::pair, std::vector> vmap( const std::vector& inputs, const std::vector& axes) override; DEFINE_NAME(Send); private: int dst_; }; class Recv : public DistPrimitive { public: Recv(Stream stream, Group group, int src) : DistPrimitive(stream, group), src_(src) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_NAME(Recv); private: int src_; }; class ReduceScatter : public DistPrimitive { public: enum ReduceType { Sum, Min, Max }; ReduceScatter(Stream stream, Group group, ReduceType reduce_type) : DistPrimitive(stream, group), reduce_type_(reduce_type) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; const char* name() const override { switch (reduce_type_) { case Sum: return "Sum ReduceScatter"; case Min: return "Min ReduceScatter"; case Max: return "Max ReduceScatter"; } return ""; } private: ReduceType reduce_type_; }; } // namespace mlx::core::distributed