MLPACK  1.0.11
simple_tolerance_termination.hpp
Go to the documentation of this file.
1 
22 #ifndef _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
23 #define _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
24 
25 #include <mlpack/core.hpp>
26 
27 namespace mlpack {
28 namespace amf {
29 
// NOTE(review): this listing is an extraction of the rendered source; source
// lines 41, 44-45 and 48-50 (the class-name line and most of the constructor)
// are missing here. The member tooltip further down in this page gives the
// constructor as
//   SimpleToleranceTermination(const double tolerance = 1e-5,
//                              const size_t maxIterations = 10000,
//                              const size_t reverseStepTolerance = 3)
//
// Residue-tolerance termination policy for AMF. MatType is the type of the
// matrix being factorized (see Initialize()).
40 template <class MatType>
42 {
43  public:
46  const size_t maxIterations = 10000,
47  const size_t reverseStepTolerance = 3)
51 
57  void Initialize(const MatType& V)
58  {
59  residueOld = DBL_MAX;
60  iteration = 1;
61  residue = DBL_MIN;
62  reverseStepCount = 0;
63  isCopy = false;
64 
65  this->V = &V;
66 
67  c_index = 0;
68  c_indexOld = 0;
69 
70  reverseStepCount = 0;
71  }
72 
// Check whether the factorization should stop, given the current factors W
// and H.
//
// Computes residue = sqrt(mean over nonzero V(i, j) of
// (V(i, j) - (W * H)(i, j))^2), i.e. the root-mean-square error of W * H over
// the nonzero entries of V, and increments the iteration count. When the
// relative improvement (residueOld - residue) / residueOld is below tolerance
// (only considered once iteration > 4), a copy of W, H and the residue is
// stored at the first such step. On termination, if a stored copy exists, W,
// H and residue are overwritten with it before returning true.
//
// @param W Left factor of the product W * H; may be overwritten with the
//     stored copy when true is returned.
// @param H Right factor of the product W * H; may be overwritten likewise.
// @return true if the termination criterion is met, false otherwise.
//
// NOTE(review): source lines 86, 123, 126 and 141 are missing from this
// extraction. The missing lines include the termination condition, and no
// visible line updates residueOld or c_indexOld.
// NOTE(review): if V has no nonzero entries, count stays 0 and sum / count
// divides by zero (NaN residue). Presumably callers never pass such a V;
// confirm.
79  bool IsConverged(arma::mat& W, arma::mat& H)
80  {
81  arma::mat WH;
82 
83  WH = W * H;
84 
85  // compute residue: RMS error of W * H over the nonzero entries of V
87  size_t n = V->n_rows;
88  size_t m = V->n_cols;
89  double sum = 0;
90  size_t count = 0;
91  for(size_t i = 0;i < n;i++)
92  {
93  for(size_t j = 0;j < m;j++)
94  {
95  double temp = 0;
// only nonzero entries of V contribute to the residue
96  if((temp = (*V)(i,j)) != 0)
97  {
98  temp = (temp - WH(i, j));
99  temp = temp * temp;
100  sum += temp;
101  count++;
102  }
103  }
104  }
105  residue = sum / count;
106  residue = sqrt(residue);
107 
108  // increment iteration count
109  iteration++;
110 
111  // if residue tolerance is not satisfied (relative improvement too small)
112  if ((residueOld - residue) / residueOld < tolerance && iteration > 4)
113  {
114  // check if this is the first of successive drops
115  if (reverseStepCount == 0 && isCopy == false)
116  {
117  // store a copy of the W and H matrices
118  isCopy = true;
119  this->W = W;
120  this->H = H;
121  // store residue values
122  c_index = residue;
124  }
125  // increase successive drop count (NOTE(review): the increment, source line 126, is missing from this extraction)
127  }
128  // if tolerance is satisfied
129  else
130  {
131  // reset the successive drop count
132  reverseStepCount = 0;
133  // if residue has dropped to or below c_indexOld, discard the stored copy
134  if(residue <= c_indexOld && isCopy == true)
135  {
136  isCopy = false;
137  }
138  }
139 
140  // check if termination criterion is met (NOTE(review): the condition itself, source line 141, is missing from this extraction)
142  {
143  // if a stored copy is present, overwrite W, H and residue with it, since
144  // it represents the minimum residue point
145  if(isCopy)
146  {
147  W = this->W;
148  H = this->H;
149  residue = c_index;
150  }
151  return true;
152  }
153  else return false;
154  }
155 
157  const double& Index() const { return residue; }
158 
160  const size_t& Iteration() const { return iteration; }
161 
163  const size_t& MaxIterations() const { return maxIterations; }
164  size_t& MaxIterations() { return maxIterations; }
165 
167  const double& Tolerance() const { return tolerance; }
168  double& Tolerance() { return tolerance; }
169 
170  private:
// Threshold on the relative residue improvement
// (residueOld - residue) / residueOld, compared in IsConverged().
172  double tolerance;
// NOTE(review): several source lines in this section are missing from this
// extraction. Declarations of maxIterations, reverseStepTolerance and
// reverseStepCount, which the methods above use, do not appear in the
// visible lines.
175 
// Non-owning pointer to the matrix being factorized; set by Initialize() and
// dereferenced by IsConverged().
177  const MatType* V;
178 
// Iteration counter; set to 1 by Initialize() and incremented by each
// IsConverged() call.
180  size_t iteration;
181 
// residue holds the value computed by the latest IsConverged() call.
// residueOld is compared against it as the previous value; no visible line
// updates residueOld outside Initialize().
183  double residueOld;
184  double residue;
// NOTE(review): normOld is not read or written in any visible line; it is
// possibly dead. Confirm against the full source.
185  double normOld;
186 
191 
// True when W, H and c_index hold a stored copy of the minimum-residue point.
194  bool isCopy;
195 
// Stored copy of the factors at the minimum-residue point.
197  arma::mat W;
198  arma::mat H;
// c_indexOld is compared against the current residue to decide when to
// discard the stored copy; where it is set other than in Initialize() is not
// visible here.
199  double c_indexOld;
// Residue value at the stored copy.
200  double c_index;
201 }; // class SimpleToleranceTermination
202 
203 }; // namespace amf
204 }; // namespace mlpack
205 
206 #endif // _MLPACK_METHODS_AMF_SIMPLE_TOLERANCE_TERMINATION_HPP_INCLUDED
207 
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: load.hpp:31
const size_t & Iteration() const
Get current iteration count.
SimpleToleranceTermination(const double tolerance=1e-5, const size_t maxIterations=10000, const size_t reverseStepTolerance=3)
Construct the policy with the given tolerance, maximum iteration count, and tolerance on successive residue drops.
const MatType * V
pointer to matrix being factorized
bool isCopy
indicates whether a stored copy of W and H corresponding to the minimum residue point is available
arma::mat W
variables to store information of the minimum residue point
size_t reverseStepTolerance
tolerance on successive residue drops
const double & Index() const
Get current value of residue.
This class implements residue tolerance termination policy.
void Initialize(const MatType &V)
Initializes the termination policy before starting the factorization.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.
const double & Tolerance() const
Access tolerance value.
const size_t & MaxIterations() const
Access upper limit of iteration count.