MatrixSquareRoot.h
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2011 Jitse Niesen <jitse@maths.leeds.ac.uk>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #ifndef EIGEN_MATRIX_SQUARE_ROOT
11 #define EIGEN_MATRIX_SQUARE_ROOT
12 
13 namespace Eigen {
14 
// Square-root helper for the real case. MatrixSquareRoot<MatrixType,0> constructs it from the
// upper quasi-triangular factor T of a real Schur decomposition; the square root is assembled
// from 1x1 and 2x2 diagonal blocks (a 2x2 block is signalled by a nonzero subdiagonal entry).
26 template <typename MatrixType>
28 {
29  public:
30 
// Stores a reference to A (not a copy), so A must outlive the call to compute().
// Asserts that A is square.
39  MatrixSquareRootQuasiTriangular(const MatrixType& A)
40  : m_A(A)
41  {
42  eigen_assert(A.rows() == A.cols());
43  }
44 
// Writes a square root of the stored matrix into result.
53  template <typename ResultType> void compute(ResultType &result);
54 
55  private:
56  typedef typename MatrixType::Index Index;
57  typedef typename MatrixType::Scalar Scalar;
58 
// Fill the diagonal blocks, then the off-diagonal blocks, of the square root sqrtT of T.
59  void computeDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
60  void computeOffDiagonalPartOfSqrt(MatrixType& sqrtT, const MatrixType& T);
// Square root of the 2x2 diagonal block starting at (i,i).
61  void compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i);
// Off-diagonal block (i,j) of sqrtT; the name gives the sizes of the i-block and j-block.
62  void compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
63  typename MatrixType::Index i, typename MatrixType::Index j);
64  void compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
65  typename MatrixType::Index i, typename MatrixType::Index j);
66  void compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
67  typename MatrixType::Index i, typename MatrixType::Index j);
68  void compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
69  typename MatrixType::Index i, typename MatrixType::Index j);
70 
// Solves A X + X B = C where all matrices are 2-by-2.
71  template <typename SmallMatrixType>
72  static void solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
73  const SmallMatrixType& B, const SmallMatrixType& C);
74 
// Non-owning reference to the constructor argument.
75  const MatrixType& m_A;
76 };
77 
78 template <typename MatrixType>
79 template <typename ResultType>
81 {
82  // Compute Schur decomposition of m_A
83  const RealSchur<MatrixType> schurOfA(m_A);
84  const MatrixType& T = schurOfA.matrixT();
85  const MatrixType& U = schurOfA.matrixU();
86 
87  // Compute square root of T
88  MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
89  computeDiagonalPartOfSqrt(sqrtT, T);
90  computeOffDiagonalPartOfSqrt(sqrtT, T);
91 
92  // Compute square root of m_A
93  result = U * sqrtT * U.adjoint();
94 }
95 
96 // pre: T is quasi-upper-triangular and sqrtT is a zero matrix of the same size
97 // post: the diagonal blocks of sqrtT are the square roots of the diagonal blocks of T
98 template <typename MatrixType>
100  const MatrixType& T)
101 {
102  const Index size = m_A.rows();
103  for (Index i = 0; i < size; i++) {
104  if (i == size - 1 || T.coeff(i+1, i) == 0) {
105  eigen_assert(T(i,i) > 0);
106  sqrtT.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
107  }
108  else {
109  compute2x2diagonalBlock(sqrtT, T, i);
110  ++i;
111  }
112  }
113 }
114 
115 // pre: T is quasi-upper-triangular and diagonal blocks of sqrtT are square root of diagonal blocks of T.
116 // post: sqrtT is the square root of T.
117 template <typename MatrixType>
118 void MatrixSquareRootQuasiTriangular<MatrixType>::computeOffDiagonalPartOfSqrt(MatrixType& sqrtT,
119  const MatrixType& T)
120 {
121  const Index size = m_A.rows();
122  for (Index j = 1; j < size; j++) {
123  if (T.coeff(j, j-1) != 0) // if T(j-1:j, j-1:j) is a 2-by-2 block
124  continue;
125  for (Index i = j-1; i >= 0; i--) {
126  if (i > 0 && T.coeff(i, i-1) != 0) // if T(i-1:i, i-1:i) is a 2-by-2 block
127  continue;
128  bool iBlockIs2x2 = (i < size - 1) && (T.coeff(i+1, i) != 0);
129  bool jBlockIs2x2 = (j < size - 1) && (T.coeff(j+1, j) != 0);
130  if (iBlockIs2x2 && jBlockIs2x2)
131  compute2x2offDiagonalBlock(sqrtT, T, i, j);
132  else if (iBlockIs2x2 && !jBlockIs2x2)
133  compute2x1offDiagonalBlock(sqrtT, T, i, j);
134  else if (!iBlockIs2x2 && jBlockIs2x2)
135  compute1x2offDiagonalBlock(sqrtT, T, i, j);
136  else if (!iBlockIs2x2 && !jBlockIs2x2)
137  compute1x1offDiagonalBlock(sqrtT, T, i, j);
138  }
139  }
140 }
141 
142 // pre: T.block(i,i,2,2) has complex conjugate eigenvalues
143 // post: sqrtT.block(i,i,2,2) is square root of T.block(i,i,2,2)
144 template <typename MatrixType>
145 void MatrixSquareRootQuasiTriangular<MatrixType>
146  ::compute2x2diagonalBlock(MatrixType& sqrtT, const MatrixType& T, typename MatrixType::Index i)
147 {
148  // TODO: This case (2-by-2 blocks with complex conjugate eigenvalues) is probably hidden somewhere
149  // in EigenSolver. If we expose it, we could call it directly from here.
150  Matrix<Scalar,2,2> block = T.template block<2,2>(i,i);
151  EigenSolver<Matrix<Scalar,2,2> > es(block);
152  sqrtT.template block<2,2>(i,i)
153  = (es.eigenvectors() * es.eigenvalues().cwiseSqrt().asDiagonal() * es.eigenvectors().inverse()).real();
154 }
155 
156 // pre: block structure of T is such that (i,j) is a 1x1 block,
157 // all blocks of sqrtT to left of and below (i,j) are correct
158 // post: sqrtT(i,j) has the correct value
159 template <typename MatrixType>
160 void MatrixSquareRootQuasiTriangular<MatrixType>
161  ::compute1x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
162  typename MatrixType::Index i, typename MatrixType::Index j)
163 {
164  Scalar tmp = (sqrtT.row(i).segment(i+1,j-i-1) * sqrtT.col(j).segment(i+1,j-i-1)).value();
165  sqrtT.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (sqrtT.coeff(i,i) + sqrtT.coeff(j,j));
166 }
167 
168 // similar to compute1x1offDiagonalBlock()
169 template <typename MatrixType>
170 void MatrixSquareRootQuasiTriangular<MatrixType>
171  ::compute1x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
172  typename MatrixType::Index i, typename MatrixType::Index j)
173 {
174  Matrix<Scalar,1,2> rhs = T.template block<1,2>(i,j);
175  if (j-i > 1)
176  rhs -= sqrtT.block(i, i+1, 1, j-i-1) * sqrtT.block(i+1, j, j-i-1, 2);
177  Matrix<Scalar,2,2> A = sqrtT.coeff(i,i) * Matrix<Scalar,2,2>::Identity();
178  A += sqrtT.template block<2,2>(j,j).transpose();
179  sqrtT.template block<1,2>(i,j).transpose() = A.fullPivLu().solve(rhs.transpose());
180 }
181 
182 // similar to compute1x1offDiagonalBlock()
183 template <typename MatrixType>
184 void MatrixSquareRootQuasiTriangular<MatrixType>
185  ::compute2x1offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
186  typename MatrixType::Index i, typename MatrixType::Index j)
187 {
188  Matrix<Scalar,2,1> rhs = T.template block<2,1>(i,j);
189  if (j-i > 2)
190  rhs -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 1);
191  Matrix<Scalar,2,2> A = sqrtT.coeff(j,j) * Matrix<Scalar,2,2>::Identity();
192  A += sqrtT.template block<2,2>(i,i);
193  sqrtT.template block<2,1>(i,j) = A.fullPivLu().solve(rhs);
194 }
195 
196 // similar to compute1x1offDiagonalBlock()
197 template <typename MatrixType>
198 void MatrixSquareRootQuasiTriangular<MatrixType>
199  ::compute2x2offDiagonalBlock(MatrixType& sqrtT, const MatrixType& T,
200  typename MatrixType::Index i, typename MatrixType::Index j)
201 {
202  Matrix<Scalar,2,2> A = sqrtT.template block<2,2>(i,i);
203  Matrix<Scalar,2,2> B = sqrtT.template block<2,2>(j,j);
204  Matrix<Scalar,2,2> C = T.template block<2,2>(i,j);
205  if (j-i > 2)
206  C -= sqrtT.block(i, i+2, 2, j-i-2) * sqrtT.block(i+2, j, j-i-2, 2);
207  Matrix<Scalar,2,2> X;
208  solveAuxiliaryEquation(X, A, B, C);
209  sqrtT.template block<2,2>(i,j) = X;
210 }
211 
212 // solves the equation A X + X B = C where all matrices are 2-by-2
213 template <typename MatrixType>
214 template <typename SmallMatrixType>
215 void MatrixSquareRootQuasiTriangular<MatrixType>
216  ::solveAuxiliaryEquation(SmallMatrixType& X, const SmallMatrixType& A,
217  const SmallMatrixType& B, const SmallMatrixType& C)
218 {
219  EIGEN_STATIC_ASSERT((internal::is_same<SmallMatrixType, Matrix<Scalar,2,2> >::value),
220  EIGEN_INTERNAL_ERROR_PLEASE_FILE_A_BUG_REPORT);
221 
222  Matrix<Scalar,4,4> coeffMatrix = Matrix<Scalar,4,4>::Zero();
223  coeffMatrix.coeffRef(0,0) = A.coeff(0,0) + B.coeff(0,0);
224  coeffMatrix.coeffRef(1,1) = A.coeff(0,0) + B.coeff(1,1);
225  coeffMatrix.coeffRef(2,2) = A.coeff(1,1) + B.coeff(0,0);
226  coeffMatrix.coeffRef(3,3) = A.coeff(1,1) + B.coeff(1,1);
227  coeffMatrix.coeffRef(0,1) = B.coeff(1,0);
228  coeffMatrix.coeffRef(0,2) = A.coeff(0,1);
229  coeffMatrix.coeffRef(1,0) = B.coeff(0,1);
230  coeffMatrix.coeffRef(1,3) = A.coeff(0,1);
231  coeffMatrix.coeffRef(2,0) = A.coeff(1,0);
232  coeffMatrix.coeffRef(2,3) = B.coeff(1,0);
233  coeffMatrix.coeffRef(3,1) = A.coeff(1,0);
234  coeffMatrix.coeffRef(3,2) = B.coeff(0,1);
235 
236  Matrix<Scalar,4,1> rhs;
237  rhs.coeffRef(0) = C.coeff(0,0);
238  rhs.coeffRef(1) = C.coeff(0,1);
239  rhs.coeffRef(2) = C.coeff(1,0);
240  rhs.coeffRef(3) = C.coeff(1,1);
241 
242  Matrix<Scalar,4,1> result;
243  result = coeffMatrix.fullPivLu().solve(rhs);
244 
245  X.coeffRef(0,0) = result.coeff(0);
246  X.coeffRef(0,1) = result.coeff(1);
247  X.coeffRef(1,0) = result.coeff(2);
248  X.coeffRef(1,1) = result.coeff(3);
249 }
250 
251 
// Square-root helper for the complex case. MatrixSquareRoot<MatrixType,1> constructs it from the
// upper triangular factor T of a complex Schur decomposition.
263 template <typename MatrixType>
265 {
266  public:
// Stores a reference to A (not a copy), so A must outlive the call to compute().
// Asserts that A is square.
267  MatrixSquareRootTriangular(const MatrixType& A)
268  : m_A(A)
269  {
270  eigen_assert(A.rows() == A.cols());
271  }
272 
// Writes a square root of the stored matrix into result, which is resized to match.
282  template <typename ResultType> void compute(ResultType &result);
283 
284  private:
// Non-owning reference to the constructor argument.
285  const MatrixType& m_A;
286 };
287 
288 template <typename MatrixType>
289 template <typename ResultType>
291 {
292  // Compute Schur decomposition of m_A
293  const ComplexSchur<MatrixType> schurOfA(m_A);
294  const MatrixType& T = schurOfA.matrixT();
295  const MatrixType& U = schurOfA.matrixU();
296 
297  // Compute square root of T and store it in upper triangular part of result
298  // This uses that the square root of triangular matrices can be computed directly.
299  result.resize(m_A.rows(), m_A.cols());
300  typedef typename MatrixType::Index Index;
301  for (Index i = 0; i < m_A.rows(); i++) {
302  result.coeffRef(i,i) = internal::sqrt(T.coeff(i,i));
303  }
304  for (Index j = 1; j < m_A.cols(); j++) {
305  for (Index i = j-1; i >= 0; i--) {
306  typedef typename MatrixType::Scalar Scalar;
307  // if i = j-1, then segment has length 0 so tmp = 0
308  Scalar tmp = (result.row(i).segment(i+1,j-i-1) * result.col(j).segment(i+1,j-i-1)).value();
309  // denominator may be zero if original matrix is singular
310  result.coeffRef(i,j) = (T.coeff(i,j) - tmp) / (result.coeff(i,i) + result.coeff(j,j));
311  }
312  }
313 
314  // Compute square root of m_A as U * result * U.adjoint()
315  MatrixType tmp;
316  tmp.noalias() = U * result.template triangularView<Upper>();
317  result.noalias() = tmp * U.adjoint();
318 }
319 
320 
// Primary template of the matrix square root helper. The IsComplex parameter defaults to
// NumTraits<Scalar>::IsComplex and selects a partial specialization for real (0) or complex (1)
// scalars. Only declarations are visible for the primary template itself.
328 template <typename MatrixType, int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
330 {
331  public:
332 
// Takes the matrix whose square root is wanted. The specializations assert that it is square
// and keep a reference to it.
340  MatrixSquareRoot(const MatrixType& A);
341 
// Writes a square root of the matrix into result.
349  template <typename ResultType> void compute(ResultType &result);
350 };
351 
352 
353 // ********** Partial specialization for real matrices **********
354 
355 template <typename MatrixType>
356 class MatrixSquareRoot<MatrixType, 0>
357 {
358  public:
359 
360  MatrixSquareRoot(const MatrixType& A)
361  : m_A(A)
362  {
363  eigen_assert(A.rows() == A.cols());
364  }
365 
366  template <typename ResultType> void compute(ResultType &result)
367  {
368  // Compute Schur decomposition of m_A
369  const RealSchur<MatrixType> schurOfA(m_A);
370  const MatrixType& T = schurOfA.matrixT();
371  const MatrixType& U = schurOfA.matrixU();
372 
373  // Compute square root of T
374  MatrixSquareRootQuasiTriangular<MatrixType> tmp(T);
375  MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
376  tmp.compute(sqrtT);
377 
378  // Compute square root of m_A
379  result = U * sqrtT * U.adjoint();
380  }
381 
382  private:
383  const MatrixType& m_A;
384 };
385 
386 
387 // ********** Partial specialization for complex matrices **********
388 
389 template <typename MatrixType>
390 class MatrixSquareRoot<MatrixType, 1>
391 {
392  public:
393 
394  MatrixSquareRoot(const MatrixType& A)
395  : m_A(A)
396  {
397  eigen_assert(A.rows() == A.cols());
398  }
399 
400  template <typename ResultType> void compute(ResultType &result)
401  {
402  // Compute Schur decomposition of m_A
403  const ComplexSchur<MatrixType> schurOfA(m_A);
404  const MatrixType& T = schurOfA.matrixT();
405  const MatrixType& U = schurOfA.matrixU();
406 
407  // Compute square root of T
408  MatrixSquareRootTriangular<MatrixType> tmp(T);
409  MatrixType sqrtT = MatrixType::Zero(m_A.rows(), m_A.rows());
410  tmp.compute(sqrtT);
411 
412  // Compute square root of m_A
413  result = U * sqrtT * U.adjoint();
414  }
415 
416  private:
417  const MatrixType& m_A;
418 };
419 
420 
// Proxy object returned by MatrixBase::sqrt(). It stores only a reference to the source
// expression (m_src), so that expression must stay alive while the proxy is in use. The square
// root itself is computed in evalTo().
433 template<typename Derived> class MatrixSquareRootReturnValue
434 : public ReturnByValue<MatrixSquareRootReturnValue<Derived> >
435 {
436  typedef typename Derived::Index Index;
437  public:
// Wraps src without evaluating it.
443  MatrixSquareRootReturnValue(const Derived& src) : m_src(src) { }
444 
// Computes the square root of m_src into result.
// NOTE(review): presumably invoked by the ReturnByValue base when the proxy is assigned to a
// matrix; confirm.
450  template <typename ResultType>
451  inline void evalTo(ResultType& result) const
452  {
// Evaluate the (possibly lazy) expression into a plain matrix before decomposing it.
453  const typename Derived::PlainObject srcEvaluated = m_src.eval();
// NOTE(review): the declaration of `me` is missing from this listing. Presumably it is a
// MatrixSquareRoot object constructed from srcEvaluated; confirm against the full source.
455  me.compute(result);
456  }
457 
// The result has the same dimensions as the (square) source.
458  Index rows() const { return m_src.rows(); }
459  Index cols() const { return m_src.cols(); }
460 
461  protected:
462  const Derived& m_src;
463  private:
465 };
466 
467 namespace internal {
468 template<typename Derived>
469 struct traits<MatrixSquareRootReturnValue<Derived> >
470 {
471  typedef typename Derived::PlainObject ReturnType;
472 };
473 }
474 
475 template <typename Derived>
476 const MatrixSquareRootReturnValue<Derived> MatrixBase<Derived>::sqrt() const
477 {
478  eigen_assert(rows() == cols());
479  return MatrixSquareRootReturnValue<Derived>(derived());
480 }
481 
482 } // end namespace Eigen
483 
484 #endif // EIGEN_MATRIX_FUNCTION