10 #ifndef EIGEN_PASTIXSUPPORT_H 11 #define EIGEN_PASTIXSUPPORT_H 14 #define PASTIX_COMPLEX COMPLEX 15 #define PASTIX_DCOMPLEX DCOMPLEX 17 #define PASTIX_COMPLEX std::complex<float> 18 #define PASTIX_DCOMPLEX std::complex<double> 31 template<
typename _MatrixType,
bool IsStrSym = false>
class PastixLU;
32 template<
typename _MatrixType,
int Options>
class PastixLLT;
33 template<
typename _MatrixType,
int Options>
class PastixLDLT;
38 template<
class Pastix>
struct pastix_traits;
40 template<
typename _MatrixType>
41 struct pastix_traits<
PastixLU<_MatrixType> >
43 typedef _MatrixType MatrixType;
44 typedef typename _MatrixType::Scalar Scalar;
45 typedef typename _MatrixType::RealScalar RealScalar;
46 typedef typename _MatrixType::Index Index;
49 template<
typename _MatrixType,
int Options>
50 struct pastix_traits<
PastixLLT<_MatrixType,Options> >
52 typedef _MatrixType MatrixType;
53 typedef typename _MatrixType::Scalar Scalar;
54 typedef typename _MatrixType::RealScalar RealScalar;
55 typedef typename _MatrixType::Index Index;
58 template<
typename _MatrixType,
int Options>
59 struct pastix_traits<
PastixLDLT<_MatrixType,Options> >
61 typedef _MatrixType MatrixType;
62 typedef typename _MatrixType::Scalar Scalar;
63 typedef typename _MatrixType::RealScalar RealScalar;
64 typedef typename _MatrixType::Index Index;
// Single-precision real variant of the eigen_pastix wrapper: adjusts the
// arguments for the two empty-input cases below and then forwards every
// argument to s_pastix.
// The meaning of iparm/dparm entries is defined by the PaStiX library and is
// not visible in this listing.
67 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
float *vals,
int *perm,
int * invp,
float *x,
int nbrhs,
int *iparm,
double *dparm)
// n == 0: the three matrix arrays are replaced by NULL before the call.
69 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
// nbrhs == 0: the right-hand-side pointer is dropped and nbrhs is forced to 1.
70 if (nbrhs == 0) {x = NULL; nbrhs=1;}
71 s_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
// Double-precision real variant of the eigen_pastix wrapper: same argument
// adjustments as the other variants, then forwards every argument to d_pastix.
74 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx,
double *vals,
int *perm,
int * invp,
double *x,
int nbrhs,
int *iparm,
double *dparm)
// n == 0: the three matrix arrays are replaced by NULL before the call.
76 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
// nbrhs == 0: the right-hand-side pointer is dropped and nbrhs is forced to 1.
77 if (nbrhs == 0) {x = NULL; nbrhs=1;}
78 d_pastix(pastix_data, pastix_comm, n, ptr, idx, vals, perm, invp, x, nbrhs, iparm, dparm);
// Single-precision complex variant of the eigen_pastix wrapper: same argument
// adjustments as the real variants, then forwards to c_pastix.
81 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<float> *vals,
int *perm,
int * invp, std::complex<float> *x,
int nbrhs,
int *iparm,
double *dparm)
// n == 0: the three matrix arrays are replaced by NULL before the call.
83 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
// nbrhs == 0: the right-hand-side pointer is dropped and nbrhs is forced to 1.
84 if (nbrhs == 0) {x = NULL; nbrhs=1;}
// vals and x are reinterpreted as PASTIX_COMPLEX* for the library call; all
// other arguments are passed through as received.
85 c_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<PASTIX_COMPLEX*>(vals), perm, invp, reinterpret_cast<PASTIX_COMPLEX*>(x), nbrhs, iparm, dparm);
// Double-precision complex variant of the eigen_pastix wrapper: same argument
// adjustments as the real variants, then forwards to z_pastix.
88 void eigen_pastix(pastix_data_t **pastix_data,
int pastix_comm,
int n,
int *ptr,
int *idx, std::complex<double> *vals,
int *perm,
int * invp, std::complex<double> *x,
int nbrhs,
int *iparm,
double *dparm)
// n == 0: the three matrix arrays are replaced by NULL before the call.
90 if (n == 0) { ptr = NULL; idx = NULL; vals = NULL; }
// nbrhs == 0: the right-hand-side pointer is dropped and nbrhs is forced to 1.
91 if (nbrhs == 0) {x = NULL; nbrhs=1;}
// vals and x are reinterpreted as PASTIX_DCOMPLEX* for the library call; all
// other arguments are passed through as received.
92 z_pastix(pastix_data, pastix_comm, n, ptr, idx, reinterpret_cast<PASTIX_DCOMPLEX*>(vals), perm, invp, reinterpret_cast<PASTIX_DCOMPLEX*>(x), nbrhs, iparm, dparm);
// Adds one, in place, to the entries of mat's outer-index and inner-index
// arrays (0-based to 1-based numbering).
// The test below checks whether the first outer-index entry is 0. The braces
// are not visible in this listing; presumably both loops are guarded by that
// test, so an already shifted matrix is left untouched -- confirm against the
// original header.
96 template <
typename MatrixType>
97 void c_to_fortran_numbering (MatrixType& mat)
99 if ( !(mat.outerIndexPtr()[0]) )
// Inclusive bound: rows()+1 outer-index entries are incremented.
// NOTE(review): the bound is rows(), not the outer size; this presumably
// relies on the matrix being square -- confirm with the callers.
102 for(i = 0; i <= mat.rows(); ++i)
103 ++mat.outerIndexPtr()[i];
// Every stored inner index (nonZeros() of them) is incremented.
104 for(i = 0; i < mat.nonZeros(); ++i)
105 ++mat.innerIndexPtr()[i];
// Subtracts one, in place, from the entries of mat's outer-index and
// inner-index arrays (1-based back to 0-based numbering).
// The test below checks whether the first outer-index entry is 1. The braces
// are not visible in this listing; presumably both loops are guarded by that
// test -- confirm against the original header.
110 template <
typename MatrixType>
111 void fortran_to_c_numbering (MatrixType& mat)
114 if ( mat.outerIndexPtr()[0] == 1 )
// Inclusive bound: rows()+1 outer-index entries are decremented.
// NOTE(review): the bound is rows(), not the outer size; this presumably
// relies on the matrix being square -- confirm with the callers.
117 for(i = 0; i <= mat.rows(); ++i)
118 --mat.outerIndexPtr()[i];
// Every stored inner index (nonZeros() of them) is decremented.
119 for(i = 0; i < mat.nonZeros(); ++i)
120 --mat.innerIndexPtr()[i];
127 template <
class Derived>
128 class PastixBase : internal::noncopyable
131 typedef typename internal::pastix_traits<Derived>::MatrixType _MatrixType;
132 typedef _MatrixType MatrixType;
133 typedef typename MatrixType::Scalar Scalar;
134 typedef typename MatrixType::RealScalar RealScalar;
135 typedef typename MatrixType::Index Index;
141 PastixBase() : m_initisOk(false), m_analysisIsOk(false), m_factorizationIsOk(false), m_isInitialized(false), m_pastixdata(0), m_size(0)
155 template<
typename Rhs>
156 inline const internal::solve_retval<PastixBase, Rhs>
159 eigen_assert(m_isInitialized &&
"Pastix solver is not initialized.");
160 eigen_assert(rows()==b.rows()
161 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
162 return internal::solve_retval<PastixBase, Rhs>(*
this, b.derived());
165 template<
typename Rhs,
typename Dest>
170 return *
static_cast<Derived*
>(
this);
172 const Derived& derived()
const 174 return *
static_cast<const Derived*
>(
this);
191 int& iparm(
int idxparam)
193 return m_iparm(idxparam);
209 double& dparm(
int idxparam)
211 return m_dparm(idxparam);
214 inline Index cols()
const {
return m_size; }
215 inline Index rows()
const {
return m_size; }
227 eigen_assert(m_isInitialized &&
"Decomposition is not initialized.");
235 template<
typename Rhs>
236 inline const internal::sparse_solve_retval<PastixBase, Rhs>
239 eigen_assert(m_isInitialized &&
"Pastix LU, LLT or LDLT is not initialized.");
240 eigen_assert(rows()==b.
rows()
241 &&
"PastixBase::solve(): invalid number of rows of the right hand side matrix b");
242 return internal::sparse_solve_retval<PastixBase, Rhs>(*
this, b.
derived());
251 void analyzePattern(ColSpMatrix& mat);
254 void factorize(ColSpMatrix& mat);
259 eigen_assert(m_initisOk &&
"The Pastix structure should be allocated first");
260 m_iparm(IPARM_START_TASK) = API_TASK_CLEAN;
261 m_iparm(IPARM_END_TASK) = API_TASK_CLEAN;
262 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
263 m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
266 void compute(ColSpMatrix& mat);
270 int m_factorizationIsOk;
271 bool m_isInitialized;
273 mutable pastix_data_t *m_pastixdata;
// Prepares the parameter arrays m_iparm / m_dparm and runs the PaStiX init
// task. Several statements of the original body are missing from this listing
// (see the gaps in the embedded line numbers), including the tail of the
// function after the final error test.
286 template <
class Derived>
287 void PastixBase<Derived>::init()
291 m_dparm.setZero(DPARM_SIZE);
// First library call, made with IPARM_MODIFY_PARAMETER = API_NO.
// NOTE(review): presumably this asks PaStiX to fill m_iparm/m_dparm with its
// defaults -- confirm against the PaStiX manual. Part of the argument list is
// not visible here.
293 m_iparm(IPARM_MODIFY_PARAMETER) = API_NO;
294 pastix(&m_pastixdata, MPI_COMM_WORLD,
296 0, 0, 0, 1, m_iparm.data(), m_dparm.data());
// Settings written into m_iparm after that call.
298 m_iparm[IPARM_MATRIX_VERIFICATION] = API_NO;
299 m_iparm[IPARM_VERBOSE] = 2;
300 m_iparm[IPARM_ORDERING] = API_ORDER_SCOTCH;
301 m_iparm[IPARM_INCOMPLETE] = API_NO;
302 m_iparm[IPARM_OOC_LIMIT] = 2000;
303 m_iparm[IPARM_RHS_MAKING] = API_RHS_B;
// NOTE(review): IPARM_MATRIX_VERIFICATION was already set to API_NO a few
// lines above (via operator[]); this second assignment looks redundant.
304 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
// Start and end task are both API_TASK_INIT, and every array argument of the
// call is 0, so no matrix data is handed to the library here.
306 m_iparm(IPARM_START_TASK) = API_TASK_INIT;
307 m_iparm(IPARM_END_TASK) = API_TASK_INIT;
308 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, 0, 0, 0, (Scalar*)0,
309 0, 0, 0, 0, m_iparm.data(), m_dparm.data());
// A non-zero IPARM_ERROR_NUMBER after the call selects this branch; the
// branch bodies are not visible in this listing.
312 if(m_iparm(IPARM_ERROR_NUMBER)) {
// Entry point taking the column-major matrix mat. Only the statements shown
// below are visible; the lines between the assertion and the parameter
// assignment are missing from this listing.
322 template <
class Derived>
323 void PastixBase<Derived>::compute(ColSpMatrix& mat)
// Precondition: mat must have as many rows as columns.
325 eigen_assert(mat.
rows() == mat.
cols() &&
"The input matrix should be squared");
330 m_iparm(IPARM_MATRIX_VERIFICATION) = API_NO;
// The solver counts as initialized exactly when the factorization flag is set.
331 m_isInitialized = m_factorizationIsOk;
// Runs the PaStiX tasks from API_TASK_ORDERING through API_TASK_ANALYSE on mat
// and records the outcome in m_analysisIsOk. Some statements of the original
// body, including the head of the library call, are missing from this listing.
335 template <
class Derived>
336 void PastixBase<Derived>::analyzePattern(ColSpMatrix& mat)
// Requires the flag m_initisOk to be set.
338 eigen_assert(m_initisOk &&
"The initialization of PaSTiX failed");
// The permutation arrays are sized to m_size before being passed to the call.
345 m_perm.resize(m_size);
346 m_invp.resize(m_size);
348 m_iparm(IPARM_START_TASK) = API_TASK_ORDERING;
349 m_iparm(IPARM_END_TASK) = API_TASK_ANALYSE;
351 mat.
valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// IPARM_ERROR_NUMBER is tested after the call; m_analysisIsOk is set to false
// on one path and true on the other (braces and else are not visible here).
354 if(m_iparm(IPARM_ERROR_NUMBER))
357 m_analysisIsOk =
false;
362 m_analysisIsOk =
true;
// Runs the PaStiX numerical factorization task (API_TASK_NUMFACT as both start
// and end task) on mat and records the outcome in m_factorizationIsOk and
// m_isInitialized. Some statements of the original body, including the head
// of the library call, are missing from this listing.
366 template <
class Derived>
367 void PastixBase<Derived>::factorize(ColSpMatrix& mat)
// Requires the flag m_analysisIsOk to be set.
370 eigen_assert(m_analysisIsOk &&
"The analysis phase should be called before the factorization phase");
371 m_iparm(IPARM_START_TASK) = API_TASK_NUMFACT;
372 m_iparm(IPARM_END_TASK) = API_TASK_NUMFACT;
376 mat.
valuePtr(), m_perm.data(), m_invp.data(), 0, 0, m_iparm.data(), m_dparm.data());
// IPARM_ERROR_NUMBER is tested after the call; both flags are set to false on
// one path and to true on the other (braces and else are not visible here).
379 if(m_iparm(IPARM_ERROR_NUMBER))
382 m_factorizationIsOk =
false;
383 m_isInitialized =
false;
388 m_factorizationIsOk =
true;
389 m_isInitialized =
true;
394 template<
typename Base>
395 template<
typename Rhs,
typename Dest>
398 eigen_assert(m_isInitialized &&
"The matrix should be factorized first");
400 THIS_METHOD_IS_ONLY_FOR_COLUMN_MAJOR_MATRICES);
405 for (
int i = 0; i < b.cols(); i++){
406 m_iparm[IPARM_START_TASK] = API_TASK_SOLVE;
407 m_iparm[IPARM_END_TASK] = API_TASK_REFINE;
409 internal::eigen_pastix(&m_pastixdata, MPI_COMM_WORLD, x.rows(), 0, 0, 0,
410 m_perm.data(), m_invp.data(), &x(0, i), rhs, m_iparm.data(), m_dparm.data());
416 return m_iparm(IPARM_ERROR_NUMBER)==0;
438 template<
typename _MatrixType,
bool IsStrSym>
439 class PastixLU :
public PastixBase< PastixLU<_MatrixType> >
442 typedef _MatrixType MatrixType;
443 typedef PastixBase<PastixLU<MatrixType> > Base;
445 typedef typename MatrixType::Index Index;
453 PastixLU(
const MatrixType& matrix):Base()
465 m_structureIsUptodate =
false;
467 grabMatrix(matrix, temp);
477 m_structureIsUptodate =
false;
479 grabMatrix(matrix, temp);
480 Base::analyzePattern(temp);
491 grabMatrix(matrix, temp);
492 Base::factorize(temp);
498 m_structureIsUptodate =
false;
499 m_iparm(IPARM_SYM) = API_SYM_NO;
500 m_iparm(IPARM_FACTORIZATION) = API_FACT_LU;
503 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
509 if(!m_structureIsUptodate)
512 m_transposedStructure = matrix.transpose();
515 for (Index j=0; j<m_transposedStructure.outerSize(); ++j)
516 for(
typename ColSpMatrix::InnerIterator it(m_transposedStructure, j); it; ++it)
519 m_structureIsUptodate =
true;
522 out = m_transposedStructure + matrix;
524 internal::c_to_fortran_numbering(out);
530 ColSpMatrix m_transposedStructure;
531 bool m_structureIsUptodate;
548 template<
typename _MatrixType,
int _UpLo>
549 class PastixLLT :
public PastixBase< PastixLLT<_MatrixType, _UpLo> >
552 typedef _MatrixType MatrixType;
553 typedef PastixBase<PastixLLT<MatrixType, _UpLo> > Base;
557 enum { UpLo = _UpLo };
563 PastixLLT(
const MatrixType& matrix):Base()
575 grabMatrix(matrix, temp);
586 grabMatrix(matrix, temp);
587 Base::analyzePattern(temp);
595 grabMatrix(matrix, temp);
596 Base::factorize(temp);
603 m_iparm(IPARM_SYM) = API_SYM_YES;
604 m_iparm(IPARM_FACTORIZATION) = API_FACT_LLT;
607 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
610 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
611 internal::c_to_fortran_numbering(out);
629 template<
typename _MatrixType,
int _UpLo>
630 class PastixLDLT :
public PastixBase< PastixLDLT<_MatrixType, _UpLo> >
633 typedef _MatrixType MatrixType;
634 typedef PastixBase<PastixLDLT<MatrixType, _UpLo> > Base;
638 enum { UpLo = _UpLo };
656 grabMatrix(matrix, temp);
667 grabMatrix(matrix, temp);
668 Base::analyzePattern(temp);
676 grabMatrix(matrix, temp);
677 Base::factorize(temp);
685 m_iparm(IPARM_SYM) = API_SYM_YES;
686 m_iparm(IPARM_FACTORIZATION) = API_FACT_LDLT;
689 void grabMatrix(
const MatrixType& matrix, ColSpMatrix& out)
692 out.template selfadjointView<Lower>() = matrix.template selfadjointView<UpLo>();
693 internal::c_to_fortran_numbering(out);
699 template<
typename _MatrixType,
typename Rhs>
700 struct solve_retval<PastixBase<_MatrixType>, Rhs>
701 : solve_retval_base<PastixBase<_MatrixType>, Rhs>
703 typedef PastixBase<_MatrixType> Dec;
704 EIGEN_MAKE_SOLVE_HELPERS(Dec,Rhs)
706 template<
typename Dest>
void evalTo(Dest& dst)
const 708 dec()._solve(rhs(),dst);
712 template<
typename _MatrixType,
typename Rhs>
713 struct sparse_solve_retval<PastixBase<_MatrixType>, Rhs>
714 : sparse_solve_retval_base<PastixBase<_MatrixType>, Rhs>
716 typedef PastixBase<_MatrixType> Dec;
717 EIGEN_MAKE_SPARSE_SOLVE_HELPERS(Dec,Rhs)
719 template<
typename Dest>
void evalTo(Dest& dst)
const 721 this->defaultEvalTo(dst);
Index rows() const
Definition: SparseMatrix.h:119
Index cols() const
Definition: SparseMatrix.h:121
const Scalar * valuePtr() const
Definition: SparseMatrix.h:131
A versatile sparse matrix representation.
Definition: SparseMatrix.h:85
Definition: Constants.h:378
const Index * outerIndexPtr() const
Definition: SparseMatrix.h:149
Base class of any sparse matrices or sparse expressions.
Definition: ForwardDeclarations.h:239
A sparse direct supernodal Cholesky (LDLT) factorization and solver based on the PaStiX library...
Definition: PaStiXSupport.h:33
void analyzePattern(const MatrixType &matrix)
Definition: PaStiXSupport.h:664
Derived & derived()
Definition: EigenBase.h:34
void analyzePattern(const MatrixType &matrix)
Definition: PaStiXSupport.h:583
void compute(const MatrixType &matrix)
Definition: PaStiXSupport.h:653
void factorize(const MatrixType &matrix)
Definition: PaStiXSupport.h:673
Definition: Constants.h:383
Definition: Eigen_Colamd.h:50
void analyzePattern(const MatrixType &matrix)
Definition: PaStiXSupport.h:475
General-purpose arrays with easy API for coefficient-wise operations.
Definition: Array.h:42
Definition: Constants.h:376
const unsigned int RowMajorBit
Definition: Constants.h:53
Interface to the PaStiX solver.
Definition: PaStiXSupport.h:31
A sparse direct supernodal Cholesky (LLT) factorization and solver based on the PaStiX library...
Definition: PaStiXSupport.h:32
void factorize(const MatrixType &matrix)
Definition: PaStiXSupport.h:592
Index rows() const
Definition: SparseMatrixBase.h:160
ComputationInfo
Definition: Constants.h:374
void compute(const MatrixType &matrix)
Definition: PaStiXSupport.h:463
Base class for all dense matrices, vectors, and expressions.
Definition: MatrixBase.h:48
const Index * innerIndexPtr() const
Definition: SparseMatrix.h:140
void factorize(const MatrixType &matrix)
Definition: PaStiXSupport.h:488
void compute(const MatrixType &matrix)
Definition: PaStiXSupport.h:572
Derived & setZero(Index size)
Definition: CwiseNullaryOp.h:515