Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions src/TiledArray/math/linalg/basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,17 @@ inline void vec_multiply(
a1.array() *= a2.array();
}

template <typename Derived>
inline auto clone(const Eigen::MatrixBase<Derived>& a) {
return a.eval();
}

template <typename XprType1, int BlockRows1, int BlockCols1, bool InnerPanel1>
inline auto clone(
const Eigen::Block<XprType1, BlockRows1, BlockCols1, InnerPanel1>& a) {
return a.eval();
}

template <typename Derived, typename S>
inline void scale(Eigen::MatrixBase<Derived>& a, S scaling_factor) {
using numeric_type = typename Eigen::MatrixBase<Derived>::value_type;
Expand Down Expand Up @@ -239,6 +250,21 @@ inline auto norm2(
return m.template lpNorm<2>();
}

template <typename Derived>
inline auto volume(const Eigen::MatrixBase<Derived>& m) {
return m.size();
}

template <typename Derived>
inline auto abs_min(const Eigen::MatrixBase<Derived>& m) {
return m.array().abs().minCoeff();
}

template <typename Derived>
inline auto abs_max(const Eigen::MatrixBase<Derived>& m) {
return m.array().abs().maxCoeff();
}

} // namespace Eigen

#ifndef TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG
Expand All @@ -253,12 +279,12 @@ inline auto norm2(
return scalapack::FN; \
return non_distributed::FN;
#elif (TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK)
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG || \
TiledArray::math::linalg::detail::prefer_distributed(MATRIX)) \
return TiledArray::math::linalg::ttg::FN; \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG || \
TiledArray::math::linalg::detail::prefer_distributed(MATRIX)) \
return TiledArray::math::linalg::ttg::FN; \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
return non_distributed::FN;
#elif !TILEDARRAY_HAS_TTG && TILEDARRAY_HAS_SCALAPACK
Expand All @@ -271,11 +297,11 @@ inline auto norm2(
return scalapack::FN; \
return non_distributed::FN;
#else // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION("TTG linear algebra backend is not available"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
#define TILEDARRAY_MATH_LINALG_DISPATCH_W_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION("TTG linear algebra backend is not available"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
return non_distributed::FN;
#endif // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
Expand All @@ -297,12 +323,12 @@ inline auto norm2(
return scalapack::FN; \
return non_distributed::FN;
#elif TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION(TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG_STRINGIFY( \
FN) " is not provided by the TTG backend"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION(TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG_STRINGIFY( \
FN) " is not provided by the TTG backend"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
return non_distributed::FN;
#elif !TILEDARRAY_HAS_TTG && TILEDARRAY_HAS_SCALAPACK
Expand All @@ -315,11 +341,11 @@ inline auto norm2(
return scalapack::FN; \
return non_distributed::FN;
#else // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION("TTG linear algebra backend is not available"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
#define TILEDARRAY_MATH_LINALG_DISPATCH_WO_TTG(FN, MATRIX) \
TA_MAX_THREADS; \
if (get_linalg_backend() == LinearAlgebraBackend::TTG) \
TA_EXCEPTION("TTG linear algebra backend is not available"); \
if (get_linalg_backend() == LinearAlgebraBackend::ScaLAPACK) \
TA_EXCEPTION("ScaLAPACK linear algebra backend is not available"); \
return non_distributed::FN;
#endif // !TILEDARRAY_HAS_TTG && !TILEDARRAY_HAS_SCALAPACK
Expand Down
15 changes: 8 additions & 7 deletions src/TiledArray/math/solvers/conjgrad.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <TiledArray/math/linalg/basic.h>
#include <TiledArray/math/solvers/diis.h>
#include "TiledArray/dist_array.h"
#include "TiledArray/type_traits.h"

namespace TiledArray::math {

Expand All @@ -44,8 +45,8 @@ namespace TiledArray::math {
/// stand-alone functions:
/// \li <tt> std::size_t volume(const D&) </tt> (returns the total number of elements)
/// \li <tt> D clone(const D&) </tt>, returns a deep copy
/// \li <tt> value_type minabs_value(const D&) </tt>
/// \li <tt> value_type maxabs_value(const D&) </tt>
/// \li <tt> value_type abs_min(const D&) </tt>
/// \li <tt> value_type abs_max(const D&) </tt>
/// \li <tt> void vec_multiply(D& a, const D& b) </tt> (element-wise multiply
/// of \c a by \c b )
/// \li <tt> value_type inner_product(const D& a, const D& b) </tt>
Expand All @@ -60,7 +61,7 @@ namespace TiledArray::math {
// clang-format on
template <typename D, typename F>
struct ConjugateGradientSolver {
typedef typename D::numeric_type value_type;
typedef TiledArray::detail::numeric_t<D> value_type;

/// \param a object of type F
/// \param b RHS
Expand All @@ -73,8 +74,8 @@ struct ConjugateGradientSolver {
value_type convergence_target = -1.0) {
std::size_t n = volume(preconditioner);

const bool use_diis = false;
DIIS<D> diis;
constexpr bool use_diis = false;
std::conditional_t<use_diis, DIIS<D>, char> diis{};

// solution vector
D XX_i;
Expand Down Expand Up @@ -120,7 +121,7 @@ struct ConjugateGradientSolver {
scale(RR_i, -1.0);
axpy(RR_i, 1.0, b); // RR_i = b - a(XX_i)

if (use_diis) diis.extrapolate(XX_i, RR_i, true);
if constexpr (use_diis) diis.extrapolate(XX_i, RR_i, true);

// z_0 = D^-1 . r_0
ZZ_i = RR_i;
Expand All @@ -144,7 +145,7 @@ struct ConjugateGradientSolver {
// r_i -= alpha_i Ap_i
axpy(RR_i, -alpha_i, APP_i);

if (use_diis) diis.extrapolate(XX_i, RR_i, true);
if constexpr (use_diis) diis.extrapolate(XX_i, RR_i, true);

const value_type r_ip1_norm = norm2(RR_i) / rhs_size;
if (r_ip1_norm < convergence_target) {
Expand Down
3 changes: 2 additions & 1 deletion src/TiledArray/math/solvers/diis.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <TiledArray/math/linalg/basic.h>
#include "TiledArray/dist_array.h"
#include "TiledArray/external/eigen.h"
#include "TiledArray/type_traits.h"

#include <Eigen/QR>
#include <deque>
Expand Down Expand Up @@ -82,7 +83,7 @@ namespace TiledArray::math {
template <typename D>
class DIIS {
public:
typedef typename D::numeric_type value_type;
typedef TiledArray::detail::numeric_t<D> value_type;
typedef typename TiledArray::detail::scalar_t<value_type> scalar_type;
typedef Eigen::Matrix<value_type, Eigen::Dynamic, Eigen::Dynamic,
Eigen::RowMajor>
Expand Down
2 changes: 1 addition & 1 deletion src/TiledArray/tile_interface/clone.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ namespace TiledArray {
/// \tparam Arg The tile argument type
/// \param arg The tile argument to be permuted
/// \return A (deep) copy of \c arg
template <typename Arg>
template <typename Arg, typename = decltype(std::declval<const Arg&>().clone())>
inline auto clone(const Arg& arg) {
return arg.clone();
}
Expand Down
14 changes: 10 additions & 4 deletions src/TiledArray/tile_op/tile_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,8 @@ inline auto min(const Arg& arg) {
/// \tparam Arg The tile argument type
/// \param arg The argument to find the maximum
/// \return A scalar that is equal to <tt>abs(max(arg))</tt>
template <typename Arg>
template <typename Arg,
typename = decltype(std::declval<const Arg&>().abs_max())>
inline auto abs_max(const Arg& arg) {
return arg.abs_max();
}
Expand All @@ -979,7 +980,8 @@ inline auto abs_max(const Arg& arg) {
/// \tparam Arg The tile argument type
/// \param arg The argument to find the minimum
/// \return A scalar that is equal to <tt>abs(min(arg))</tt>
template <typename Arg>
template <typename Arg,
typename = decltype(std::declval<const Arg&>().abs_min())>
inline auto abs_min(const Arg& arg) {
return arg.abs_min();
}
Expand All @@ -991,7 +993,9 @@ inline auto abs_min(const Arg& arg) {
/// \param left The left-hand argument tile
/// \param right The right-hand argument tile
/// \return A scalar that is equal to <tt>sum_i left[i] * right[i]</tt>
template <typename Left, typename Right>
template <typename Left, typename Right,
typename = decltype(std::declval<const Left&>().dot(
std::declval<const Right&>()))>
inline auto dot(const Left& left, const Right& right) {
return left.dot(right);
}
Expand All @@ -1003,7 +1007,9 @@ inline auto dot(const Left& left, const Right& right) {
/// \param left The left-hand argument tile
/// \param right The right-hand argument tile
/// \return A scalar that is equal to <tt>sum_i conj(left[i]) * right[i]</tt>
template <typename Left, typename Right>
template <typename Left, typename Right,
typename = decltype(std::declval<const Left&>().inner_product(
std::declval<const Right&>()))>
inline auto inner_product(const Left& left, const Right& right) {
return left.inner_product(right);
}
Expand Down
53 changes: 53 additions & 0 deletions tests/solvers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,49 @@ struct validate<DistArray<Tile, Policy>> {
}
};

// Eigen specializations

template <>
struct make_Ax<Eigen::VectorXd> {
using T = Eigen::VectorXd;

struct Ax {
Ax() : A_(3, 3) { A_ << 1, 2, 3, 2, 5, 8, 3, 8, 15; }
void operator()(const T& x, T& result) const { result = A_ * x; }
Eigen::MatrixXd A_;
};
Ax operator()() const { return Ax{}; }
};

template <>
struct make_b<Eigen::VectorXd> {
using T = Eigen::VectorXd;

T operator()() const {
T result(3);
result << 1, 4, 0;
return result;
}
};

template <>
struct make_pc<Eigen::VectorXd> {
using T = Eigen::VectorXd;

T operator()() const { return T::Ones(3); }
};

template <>
struct validate<Eigen::VectorXd> {
using T = Eigen::VectorXd;

bool operator()(const T& x) const {
T ref(3);
ref << -6.5, 9., -3.5;
return (x - ref).norm() < 1e-11;
}
};

BOOST_AUTO_TEST_SUITE(solvers)

BOOST_AUTO_TEST_CASE_TEMPLATE(conjugate_gradient, Array, array_types) {
Expand All @@ -178,4 +221,14 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(conjugate_gradient, Array, array_types) {
BOOST_CHECK(validate<Array>{}(x));
}

BOOST_AUTO_TEST_CASE(conjugate_gradient_eigen) {
using T = Eigen::VectorXd;
auto Ax = make_Ax<T>{}();
auto b = make_b<T>{}();
auto pc = make_pc<T>{}();
T x;
ConjugateGradientSolver<T, decltype(Ax)>{}(Ax, b, x, pc, 1e-11);
BOOST_CHECK(validate<T>{}(x));
}

BOOST_AUTO_TEST_SUITE_END()
Loading