// Copyright © 2023-2024 Apple Inc. #pragma once #include #define LAPACK_COMPLEX_CUSTOM #define lapack_complex_float std::complex #define lapack_complex_double std::complex #define lapack_complex_float_real(z) ((z).real()) #define lapack_complex_float_imag(z) ((z).imag()) #define lapack_complex_double_real(z) ((z).real()) #define lapack_complex_double_imag(z) ((z).imag()) #ifdef MLX_USE_ACCELERATE #include #else #include #include #endif #if defined(LAPACK_GLOBAL) || defined(LAPACK_NAME) // This is to work around a change in the function signatures of lapack >= 3.9.1 // where functions taking char* also include a strlen argument, see a similar // change in OpenCV: // https://github.com/opencv/opencv/blob/1eb061f89de0fb85c4c75a2deeb0f61a961a63ad/cmake/OpenCVFindLAPACK.cmake#L57 #define MLX_LAPACK_FUNC(f) LAPACK_##f #else #define MLX_LAPACK_FUNC(f) f##_ #endif #define INSTANTIATE_LAPACK_REAL(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_REAL(geqrf) INSTANTIATE_LAPACK_REAL(orgqr) INSTANTIATE_LAPACK_REAL(syevd) INSTANTIATE_LAPACK_REAL(potrf) INSTANTIATE_LAPACK_REAL(getrf) INSTANTIATE_LAPACK_REAL(getri) INSTANTIATE_LAPACK_REAL(trtri) #define INSTANTIATE_LAPACK_COMPLEX(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_COMPLEX(heevd) #define INSTANTIATE_LAPACK_ALL(FUNC) \ template \ void FUNC(Args... args) { \ if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(s##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v) { \ MLX_LAPACK_FUNC(d##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(c##FUNC)(std::forward(args)...); \ } else if constexpr (std::is_same_v>) { \ MLX_LAPACK_FUNC(z##FUNC)(std::forward(args)...); \ } \ } INSTANTIATE_LAPACK_ALL(geev) INSTANTIATE_LAPACK_ALL(gesdd)